You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

context.py 19 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. from .ndarray import cpu, gpu, rcpu, rgpu, DLContext, is_gpu_ctx
  2. import contextlib
  3. import re
  4. class DeviceGroup(object):
  5. def __init__(self, ctxs):
  6. self._contexts = self.parse_contexts(ctxs)
  7. self.get_servers_n_workers()
  8. @classmethod
  9. def parse_contexts(cls, ctxs):
  10. if isinstance(ctxs, DeviceGroup):
  11. return ctxs
  12. if isinstance(ctxs, str):
  13. ctxs = re.split(';|,| +', ctxs.lower())
  14. if not isinstance(ctxs, list):
  15. ctxs = [ctxs]
  16. new_ctxs = []
  17. for c in ctxs:
  18. if isinstance(c, tuple):
  19. c = tuple([ccc for ccc in [cls.str2ctx(cc)
  20. for cc in c] if ccc is not None])
  21. else:
  22. c = cls.str2ctx(c)
  23. if c is not None:
  24. new_ctxs.append(c)
  25. return new_ctxs
  26. @classmethod
  27. def str2ctx(cls, c):
  28. if isinstance(c, str):
  29. c = c.lower().split(':')
  30. assert c[-2] in ('cpu', 'gpu'), 'Context invalid: %s' % c
  31. hostname = 'localhost' if len(c) == 2 else c[0]
  32. idx = int(c[-1])
  33. c = rcpu(hostname, idx) if c[-2] == 'cpu' else rgpu(hostname, idx)
  34. assert isinstance(c, DLContext), 'Context invalid: %s' % c
  35. return c
  36. def index(self, ctx):
  37. return self._contexts.index(ctx)
  38. def __getitem__(self, key):
  39. return self._contexts[key]
  40. def __iter__(self):
  41. return iter(self._contexts)
  42. def __len__(self):
  43. return len(self._contexts)
  44. @property
  45. def worker_num(self):
  46. return len(self._workers)
  47. @property
  48. def server_num(self):
  49. return len(self._servers)
  50. @property
  51. def workers(self):
  52. return self._workers
  53. @property
  54. def servers(self):
  55. return self._servers
  56. def get_servers_n_workers(self):
  57. self._workers = []
  58. self._servers = []
  59. for ctx in self._contexts:
  60. if isinstance(ctx, tuple) or is_gpu_ctx(ctx):
  61. self._workers.append(ctx)
  62. else:
  63. self._servers.append(ctx)
  64. def __repr__(self):
  65. result = 'DeviceGroup('
  66. for c in self._contexts:
  67. result += ('(' + ', '.join([str(cc) for cc in c]) +
  68. '), ') if isinstance(c, tuple) else '%s, ' % c
  69. result += ')'
  70. return result
  71. def __hash__(self):
  72. if not hasattr(self, 'hash'):
  73. self.hash = hash(
  74. tuple(sorted(self._contexts, key=lambda x: x.device_id)))
  75. return self.hash
  76. def __eq__(self, other):
  77. return hash(self) == hash(other)
  78. class ContextStack(object):
  79. def __init__(self):
  80. self._stack = []
  81. def peek(self):
  82. return self._stack[-1] if self._stack else None
  83. def push(self, ctx):
  84. return self._stack.append(ctx)
  85. def pop(self):
  86. self._stack.pop()
  87. _default_ctx_stack = ContextStack()
  88. def get_current_context():
  89. return _default_ctx_stack.peek()
  90. @contextlib.contextmanager
  91. def context(ctx):
  92. try:
  93. ctx = DeviceGroup(ctx)
  94. _default_ctx_stack.push(ctx)
  95. yield ctx
  96. finally:
  97. _default_ctx_stack.pop()
  98. def check_worker(ctx):
  99. # if the context is GPU or is a tuple (which means model parallel),
  100. # we regard it as a worker
  101. return isinstance(ctx, tuple) or is_gpu_ctx(ctx)
  102. def get_launch_config_by_traverse_nodes(node_list, default_ctx):
  103. node_strategy = dict()
  104. devices = set()
  105. for ctx in default_ctx:
  106. if isinstance(ctx, tuple):
  107. devices.update(ctx)
  108. else:
  109. devices.add(ctx)
  110. launchPS = default_ctx.server_num > 0
  111. launchMPI = (not launchPS) and default_ctx.worker_num > 1
  112. nrank = default_ctx.worker_num
  113. for node in node_list:
  114. traverse_dfs(node, node_strategy, devices, nrank)
  115. launchPS = launchPS or any([x == 'PS' for x in node_strategy.values()])
  116. launchMPI = launchMPI or any(
  117. [x == 'AllReduce' for x in node_strategy.values()])
  118. return launchMPI, launchPS, node_strategy, devices
  119. def traverse_dfs(node, node_strategy, devices, nrank):
  120. if node in node_strategy:
  121. return
  122. strategy = None
  123. if node.raw_ctx is not None and node.raw_ctx.server_num > 0 and node.raw_ctx.worker_num > 0:
  124. strategy = 'PS'
  125. elif node.raw_ctx is not None and node.raw_ctx.worker_num > 1:
  126. strategy = 'AllReduce'
  127. node_strategy[node] = strategy
  128. for ctx in node.raw_ctx:
  129. if isinstance(ctx, tuple):
  130. devices.update(ctx)
  131. else:
  132. devices.add(ctx)
  133. local_nrank = nrank if node.raw_ctx is None else node.raw_ctx.worker_num
  134. assert local_nrank in (
  135. 0, nrank), 'Number of workers not consist: (%d, %d).' % (local_nrank, nrank)
  136. for n in node.inputs:
  137. traverse_dfs(n, node_strategy, devices, nrank)
def assign_context_by_traverse_nodes(node_list, ctx, mpi_comm, p2p_stream):
    """Place graph nodes relative to this process and insert pipeline send/receive ops.

    Walks the graph depth-first from every node in ``node_list`` (through
    ``node.inputs``) and records, for each node, where it lives relative to
    ``ctx`` (this process's device):

    * ``dp_index_map[node]`` -- index of the slot in ``node.raw_ctx.workers``
      that contains ``ctx`` (-1 if none).
    * ``mp_index_map[node]`` -- position of ``ctx`` inside that slot when the
      slot is a model-parallel tuple (-1 otherwise).

    When a node and one of its inputs sit in different slots, the consumer
    side gets a ``pipeline_receive_op`` in place of the input, and the
    producer side gets a ``pipeline_send_op`` appended to its evaluation list.
    Dispatch ops (``DispatchOp`` / ``DispatchGradientOp``) mark model-parallel
    boundaries:
    * before sending, data is split with ``split_op``;
    * after receiving, pieces are concatenated with ``concat_op`` and
      duplicates are summed with ``add_op``.

    The graph is mutated in place (``node.inputs``, ``node.ctx``, and the
    optimizer's ``params`` / ``inputs``).

    Args:
        node_list: nodes requested for evaluation.
        ctx: device context of the current process.
        mpi_comm: communicator; ``getRankFromDevice(hostname, device_id)`` is
            used to turn a peer device into a rank.
        p2p_stream: stream passed to every send/receive op created here.

    Returns:
        ``(my_eval_nodes, trainable_params, has_send_recv)``:
        - my_eval_nodes: nodes this process must evaluate. This includes the
          send ops created here, and the optimizer if this process holds any
          trainable params.
        - trainable_params: trainable ``PlaceholderOp``s placed on this process.
        - has_send_recv: whether any send or receive op was created.
    """
    # Imported here rather than at module level, presumably to avoid
    # circular imports.
    from .dataloader import DataloaderOp
    from .optimizer import OptimizerOp
    from .gpu_ops.PipelineSend import pipeline_send_op
    from .gpu_ops.PipelineReceive import pipeline_receive_op
    from .gpu_ops.Variable import PlaceholderOp
    from .gpu_ops.Dispatch import DispatchOp, DispatchGradientOp
    from .gpu_ops.Concat import concat_op
    from .gpu_ops.Split import split_op
    from .gpu_ops.AddElewise import add_op

    def receive_model_parallel(prev_input, node):
        """Return the op providing ``prev_input``'s output to ``node`` on this process.

        Built once per ``prev_input`` and cached in ``recv_src``.
        """
        # Expected precondition (not enforced here):
        # assert dp_index_map[prev_input] < 0 and dp_index_map[node] >= 0
        dev_pos = dp_index_map[node]
        if isinstance(node.raw_ctx.workers[dev_pos], tuple):
            # here we receive from a node on one device dispatching to many
            # in this case current node MUST have mp_index, and the split will be handled in sending
            assert mp_index_map[node] >= 0, 'Now only support 1 to N.'
            hostname = prev_input.raw_ctx.workers[dev_pos].hostname
            target_id = prev_input.raw_ctx.workers[dev_pos].device_id
            if prev_input not in recv_src:
                recv_src[prev_input] = pipeline_receive_op(mpi_comm.getRankFromDevice(
                    hostname, target_id), mpi_comm, stream=p2p_stream, ctx=ctx)
            return recv_src[prev_input]
        else:
            # here we receive from a node on multiple devices
            # in this case current node MUST NOT have mp_index, and handle the combination
            target = node_tar_states_map[prev_input]
            assert mp_index_map[node] < 0 and (target is None or all(
                [ts == 1 for ts in target])), 'Now only support N to 1.'
            if prev_input not in recv_src:
                # Index of the last sender device used by make_comb; devices
                # are consumed in the order of prev_input's worker tuple.
                device_index = -1

                def make_comb(devices, cur_states, depth):
                    # Rebuild the full tensor recursively. At each depth,
                    # concatenate cur_states[depth] sub-results along axis
                    # ``depth``; each leaf receives from the next device.
                    if depth == len(cur_states):
                        nonlocal device_index
                        device_index += 1
                        return pipeline_receive_op(mpi_comm.getRankFromDevice(devices[device_index].hostname, devices[device_index].device_id), mpi_comm, stream=p2p_stream, ctx=ctx)
                    else:
                        result = make_comb(devices, cur_states, depth + 1)
                        for _ in range(1, cur_states[depth]):
                            result = concat_op(result, make_comb(
                                devices, cur_states, depth + 1), axis=depth, ctx=ctx)
                        return result
                res = make_comb(
                    prev_input.raw_ctx.workers[dev_pos], node_cur_states_map[prev_input], 0)
                # Each additional duplicate is received the same way and summed in.
                for _ in range(1, node_cur_duplicate_map.get(prev_input, 1)):
                    res = add_op(res, make_comb(
                        prev_input.raw_ctx.workers[dev_pos], node_cur_states_map[prev_input], 0), ctx=ctx)
                # Every device in the sender's tuple must have been used exactly once.
                assert device_index + \
                    1 == len(prev_input.raw_ctx.workers[dev_pos])
                recv_src[prev_input] = res
            return recv_src[prev_input]

    def send_model_parallel(prev_input, node):
        """Create send op(s) shipping ``prev_input``'s output to ``node``'s devices.

        Send ops are cached in ``send_dst`` by ``(prev_input, target device_id)``
        and appended to ``my_eval_nodes`` when first created.
        """
        # Expected precondition (not enforced here):
        # assert dp_index_map[prev_input] >= 0 and dp_index_map[node] < 0
        dev_pos = dp_index_map[prev_input]
        if not isinstance(prev_input.raw_ctx.workers[dev_pos], tuple):
            # here we send from a node on one device dispatching to many nodes
            # in this case current node MUST have mp_index, and the split will be handled in sending
            # (the assert below checks the sending side: prev_input must not be model parallel)
            assert mp_index_map[prev_input] < 0, 'Now only support 1 to N.'
            # Index of the next receiver device to send to.
            device_index = 0

            def make_split(devices, target_states, cur_states, depth):
                # Enumerate every index combination of target_states (last
                # dimension varying fastest); each complete combination is
                # sent to the next device in ``devices``.
                if len(target_states) == depth:
                    nonlocal device_index
                    hostname = devices[device_index].hostname
                    target_id = devices[device_index].device_id
                    device_index += 1
                    key = (prev_input, target_id)
                    if key not in send_dst:
                        # No split op is needed when every dimension has a single part.
                        cur_node = prev_input if all([x == 1 for x in target_states]) else split_op(
                            prev_input, list(range(len(target_states))), list(cur_states), list(target_states), ctx=ctx)
                        target_rank = mpi_comm.getRankFromDevice(
                            hostname, target_id)
                        send_dst[key] = pipeline_send_op(
                            cur_node, target_rank, mpi_comm, stream=p2p_stream, ctx=ctx)
                        my_eval_nodes.append(send_dst[key])
                else:
                    for ts in range(target_states[depth]):
                        cur_states[depth] = ts
                        make_split(devices, target_states,
                                   cur_states, depth + 1)
            for _ in range(node_tar_duplicate_map.get(prev_input, 1)):
                cur_states = [0 for _ in range(
                    len(node_tar_states_map[prev_input]))]
                make_split(
                    node.raw_ctx.workers[dev_pos], node_tar_states_map[prev_input], cur_states, 0)
            # Every device in the receiver's tuple must have been targeted exactly once.
            assert device_index == len(node.raw_ctx.workers[dev_pos])
        else:
            # here we send from a node on multiple devices to one node
            # in this case current node MUST NOT have mp_index, and the combination will be handled in receiving
            # (the assert below checks that the sender, prev_input, is model parallel on this process)
            target = node_tar_states_map[prev_input]
            assert mp_index_map[prev_input] >= 0 and (target is None or all(
                [ts == 1 for ts in target])), 'Now only support N to 1.'
            hostname = node.raw_ctx.workers[dev_pos].hostname
            target_id = node.raw_ctx.workers[dev_pos].device_id
            key = (prev_input, target_id)
            if key not in send_dst:
                send_dst[key] = pipeline_send_op(prev_input, mpi_comm.getRankFromDevice(
                    hostname, target_id), mpi_comm, stream=p2p_stream, ctx=ctx)
                my_eval_nodes.append(send_dst[key])

    def assign_ctx(node):
        """Assign ``node`` and, recursively, its inputs relative to this process.

        Each node is processed at most once (guarded by ``dp_index_map``).
        """
        if node in dp_index_map:
            return
        # Mark as visited (and "not on this process") before recursing.
        mp_index_map[node] = -1
        dp_index_map[node] = -1
        if isinstance(node, DataloaderOp):
            return
        elif isinstance(node, OptimizerOp):
            nonlocal opt
            assert opt is None, 'Multiple optimizer is invalid.'
            opt = node
            for n in node.inputs:
                assign_ctx(n)
            # Rebuild the gradient list so that it only covers params placed on
            # this process (collected in trainable_params during traversal).
            grads = []
            original_params = node.optimizer.params
            for ind, param in enumerate(original_params):
                ori_grad = node.inputs[ind]
                if param in trainable_params:
                    # Local param: a gradient behind a dispatch op is received
                    # from the process that produces it.
                    new_grad = receive_model_parallel(ori_grad.inputs[0], param) if isinstance(
                        ori_grad, (DispatchOp, DispatchGradientOp)) else ori_grad
                    grads.append(new_grad)
                elif isinstance(ori_grad, (DispatchOp, DispatchGradientOp)):
                    # Remote param: if its gradient is produced on this
                    # process, send it to the param's devices.
                    real_input = ori_grad.inputs[0]
                    my_pos = dp_index_map[real_input]
                    if my_pos >= 0:
                        send_model_parallel(ori_grad.inputs[0], param)
            if trainable_params:
                # NOTE(review): ``grads`` follows original_params order while the
                # new params follow traversal order of trainable_params --
                # confirm the two orders always agree.
                # indices = [original_params.index(param) for param in trainable_params]
                node.optimizer.params = trainable_params
                # grads = [node.inputs[index] for index in indices]
                node.inputs = grads
                node.ctx = ctx
                my_eval_nodes.append(node)
        elif isinstance(node, DispatchOp):
            # A dispatch op is not placed itself; it records the target parts
            # and duplicate count for its real input.
            real_node = node.inputs[0]
            assign_ctx(real_node)
            node_tar_states_map[real_node] = node.parts
            node_tar_duplicate_map[real_node] = node.duplicate
        elif isinstance(node, DispatchGradientOp):
            # The target layout is taken from the current layout of inputs[1]
            # (None / 1 when it is not recorded).
            real_node = node.inputs[0]
            assign_ctx(real_node)
            assign_ctx(node.inputs[1])
            node_tar_states_map[real_node] = node_cur_states_map.get(
                node.inputs[1], None)
            node_tar_duplicate_map[real_node] = node_cur_duplicate_map.get(
                node.inputs[1], 1)
        else:
            # now we only support SAME model parallel in data parallel
            # and 1 context can only appear once
            mp_index = -1
            dp_index = -1
            # Locate ctx among node's worker slots; inside a tuple slot
            # (model parallel), mp_index is ctx's position in the tuple.
            for i, c in enumerate(node.raw_ctx.workers):
                if isinstance(c, tuple) and ctx in c:
                    mp_index = c.index(ctx)
                    dp_index = i
                elif ctx == c:
                    dp_index = i
            mp_index_map[node] = mp_index
            dp_index_map[node] = dp_index
            need_states_deduction = False
            for i, n in enumerate(node.inputs):
                if isinstance(n, DataloaderOp):
                    # Dataloaders are not traversed. One is scheduled only if
                    # this process hosts node and the dataloader was requested.
                    if dp_index >= 0 and n in node_list and n not in my_eval_nodes:
                        my_eval_nodes.append(n)
                    continue
                assign_ctx(n)
                # we assume that in model parallel + data parallel mode,
                # devices number of each stage is equal
                # the device in correspondent place will communicate with each other
                # TODO: not support following case: context(1,5) -> context(5,1); context(1,5) -> context(3,1)
                # solution: modify following is_my_node logic to support
                # TODO: not support the case that each process has different group init numbers, since there is an AllGather in mpi_nccl_comm's init
                # solution: modify mpi_nccl_comm class, so that the MPI part only process once while nccl has several groups
                assert node.raw_ctx.worker_num == n.raw_ctx.worker_num, \
                    'In pipeline + data parallel, devices number of each stage should be equal!'
                if isinstance(n, (DispatchOp, DispatchGradientOp)):
                    need_states_deduction = True
                    # here we only allow pipeline + model parallel, which means the devices are all different
                    # TODO: release the constraint above
                    # here in every context each device appear only once
                    # TODO: consider whether or not release the constraint above?
                    # here we only allow one2n/n2one/n2n, can not change from x to y where x != 1 and y != 1 and x != y in dimension-granularity
                    # TODO: consider whether or not release the constraint above? too complex and not realistic!
                    real_input = n.inputs[0]
                    if dp_index >= 0 and dp_index_map[real_input] < 0:
                        node.inputs[i] = receive_model_parallel(
                            real_input, node)
                    elif dp_index < 0 and dp_index_map[real_input] >= 0:
                        send_model_parallel(real_input, node)
                else:
                    assert mp_index < 0 and mp_index_map[n] < 0
                    # handle receiving: this process hosts node but not n, so
                    # receive from n's worker in the same slot index
                    if dp_index >= 0 and dp_index != dp_index_map[n]:
                        my_pos = dp_index
                        hostname = n.raw_ctx.workers[my_pos].hostname
                        target_id = n.raw_ctx.workers[my_pos].device_id
                        if n not in recv_src:
                            recv_src[n] = pipeline_receive_op(mpi_comm.getRankFromDevice(
                                hostname, target_id), mpi_comm, stream=p2p_stream, ctx=ctx)
                        node.inputs[i] = recv_src[n]
                    # handle sending: this process hosts n but not node, so
                    # send to node's worker in the same slot index
                    if dp_index_map[n] >= 0 and dp_index != dp_index_map[n]:
                        my_pos = dp_index_map[n]
                        hostname = node.raw_ctx.workers[my_pos].hostname
                        target_id = node.raw_ctx.workers[my_pos].device_id
                        key = (n, target_id)
                        if key not in send_dst:
                            send_dst[key] = pipeline_send_op(n, mpi_comm.getRankFromDevice(
                                hostname, target_id), mpi_comm, stream=p2p_stream, ctx=ctx)
                            my_eval_nodes.append(send_dst[key])
            if dp_index >= 0:
                # node is placed on this process
                node.ctx = ctx
                if node in node_list:
                    my_eval_nodes.append(node)
                if isinstance(node, PlaceholderOp) and node.trainable:
                    trainable_params.append(node)
            if need_states_deduction:
                # Derive node's own split states / duplicate count from its
                # inputs through node.deduce_states.
                # NOTE(review): inputs already replaced by receive ops above are
                # no longer Dispatch ops, so they contribute
                # node_cur_states_map.get(<receive op>, None) here -- confirm
                # that is intended.
                input_states = []
                input_duplicates = []
                for n in node.inputs:
                    if isinstance(n, (DispatchOp, DispatchGradientOp)):
                        input_states.append(node_tar_states_map[n.inputs[0]])
                        input_duplicates.append(
                            node_tar_duplicate_map[n.inputs[0]])
                    else:
                        input_states.append(node_cur_states_map.get(n, None))
                        input_duplicates.append(
                            node_cur_duplicate_map.get(n, 1))
                node_cur_states_map[node], node_cur_duplicate_map[node] = node.deduce_states(
                    input_states, input_duplicates)

    opt = None  # the single OptimizerOp found during traversal
    trainable_params = []  # trainable PlaceholderOps placed on this process
    send_dst = {}  # (node, target device_id) -> pipeline_send_op
    recv_src = {}  # source node -> op yielding its received value here
    mp_index_map = {}  # model parallel index
    dp_index_map = {}  # data parallel index
    node_cur_duplicate_map = {}  # save nodes' current duplicate counts
    node_tar_duplicate_map = {}  # save nodes' target duplicate counts
    node_cur_states_map = {}  # save nodes' current states
    node_tar_states_map = {}  # save nodes' target states
    my_eval_nodes = []
    for node in node_list:
        assign_ctx(node)
    has_send_recv = send_dst != {} or recv_src != {}
    return my_eval_nodes, trainable_params, has_send_recv