You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

executor.py 37 kB

4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921
  1. """ library to take autodiff and execute a computation graph """
  2. from __future__ import absolute_import
  3. from .BatchNorm import Batch_NormalizationOp
  4. import numpy as np
  5. from scipy.sparse import spmatrix, coo_matrix
  6. from .. import ndarray
  7. from .._base import DNNL_LIB
  8. from ..cpu_links import array_set as cpu_array_set
  9. from .Variable import PlaceholderOp # add for optimizer
  10. from ..dataloader import DataloaderOp, GNNDataLoaderOp
  11. from .AllReduceCommunicate import AllReduceCommunicateOp
  12. from .ParameterServerCommunicate import ParameterServerCommunicateOp, ParameterServerSparsePullOp, parameterServerSparsePull_op
  13. from .AddElewise import add_op
  14. from .DataTransfer import DataH2DOp, DataD2HOp, DataD2HSparseOp
  15. from .EmbeddingLookUp import EmbeddingLookUp, EmbeddingLookUp_Gradient
  16. from ..optimizer import OptimizerOp
  17. from . import OnesLike
  18. from ..stream import create_stream_handle, Event
  19. from ..context import get_current_context, get_launch_config_by_traverse_nodes, assign_context_by_traverse_nodes, DeviceGroup
  20. from .PipelineSend import PipelineSendOp
  21. from .PipelineReceive import PipelineReceiveOp
  22. from .Dropout import DropoutOp
  23. from .LayerNorm import Layer_NormalizationOp
  24. from operator import add
  25. from functools import reduce
  26. import ctypes
  27. import os
  28. from time import time
  29. def path_to_lib(name):
  30. curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
  31. lib_path = os.path.join(curr_path, '../../../build/lib/')
  32. return os.path.join(lib_path, name)
  33. def wrapped_mpi_nccl_init(init_nccl=True, devices=None):
  34. from ..communicator.mpi_nccl_comm import mpi_communicator
  35. global mpi_comm
  36. global nccl_comm
  37. if 'mpi_comm' not in globals():
  38. mpi_comm = mpi_communicator(devices=devices)
  39. if 'nccl_comm' not in globals():
  40. nccl_comm = mpi_comm.ncclInit() if init_nccl else None
  41. return nccl_comm
  42. def new_group_comm(devices_context=None):
  43. assert 'mpi_comm' in globals()
  44. global mpi_comm
  45. if devices_context is None:
  46. comm = mpi_comm.ncclInit()
  47. else:
  48. comm = mpi_comm.ncclGroupInit(devices_context)
  49. return comm
def get_nccl_communicate():
    """Return the process-wide NCCL communicator created by wrapped_mpi_nccl_init.

    Raises NameError if no initialization call has populated the global yet.
    """
    global nccl_comm
    return nccl_comm
def get_worker_communicate():
    """Return the parameter-server communication handle (the loaded libps library).

    Raises NameError if worker_init/server_init/scheduler_init has not run yet.
    """
    global ps_comm
    return ps_comm
  56. def worker_init():
  57. global ps_comm
  58. ll = ctypes.cdll.LoadLibrary
  59. ps_comm = ll(path_to_lib("libps.so"))
  60. ps_comm.Init()
def worker_finish():
    """Finalize the PS worker; uses the module-global ps_comm set by worker_init."""
    ps_comm.Finalize()
  63. def server_init():
  64. global ps_comm
  65. ll = ctypes.cdll.LoadLibrary
  66. ps_comm = ll(path_to_lib("libps.so"))
  67. ps_comm.Init()
  68. ps_comm.StartServer()
def server_finish():
    """Finalize the PS server; uses the module-global ps_comm set by server_init."""
    ps_comm.Finalize()
  71. def scheduler_init():
  72. global ps_comm
  73. ll = ctypes.cdll.LoadLibrary
  74. ps_comm = ll(path_to_lib("libps.so"))
  75. ps_comm.Init()
def scheduler_finish():
    """Finalize the PS scheduler; uses the module-global ps_comm set by scheduler_init."""
    ps_comm.Finalize()
class HetuConfig(object):
    """Holds all execution-time configuration for an Executor.

    Determines the communication mode (single-GPU / PS / AllReduce / Hybrid),
    initializes the PS and/or NCCL communicators as needed, creates CUDA
    streams, and runs the topo-sort hooks that move parameter tensors from
    PlaceholderOp nodes into placeholder_to_arr_map.
    """
    __slots__ = [
        'eval_node_list',
        'train_name',
        'val_name',
        'context',
        'seed',
        'np_rand',
        'comm_mode',
        'node_strategy',
        'context_launch',
        'ps_comm',
        'nccl_comm',
        'local_rank',
        'rank',
        'nrank',
        'p2p_stream',
        'comp_stream',
        'nccl_stream',
        'h2d_stream',
        'd2h_stream',
        'h2d_ops',
        'd2h_ops',
        'ps_map',
        'infer_ps_map',
        'dataloader_ops',
        'use_sparse_pull',
        'cstable_policy',
        'inference',
        'enable_lazy',
        'bsp',
        'prefetch',
        'cache_bound',
        'log_path',
        'my_eval_nodes',
        'param_allreduce_group',
        'placeholder_to_arr_map'
    ]

    def __init__(
        self,
        eval_node_list,
        train_name,
        val_name,
        ctx=None,
        seed=None,
        comm_mode=None,
        use_sparse_pull=True,
        cstable_policy=None,
        bsp=False,
        prefetch=True,
        enable_lazy=True,
        cache_bound=100,
        log_path=None,
    ):
        '''
        ctx: default device context; when it is a DeviceGroup the comm mode
            is derived from the graph instead of taken from comm_mode.
        comm_mode: communication mode, should be one of the following
            None -> Single GPU
            PS -> Parameter Server
            AllRedeuce -> MPI AllReduce
            Hybrid -> Parameter Server for Sparse Parameter and MPI AllReduce for Dense Parameter
        '''
        self.eval_node_list = eval_node_list
        self.train_name = train_name
        self.val_name = val_name
        # check context
        if ctx is None:
            ctx = get_current_context()
        assert ctx, 'Default context should be determined.'
        self.comm_mode = comm_mode
        self.node_strategy = {}
        local_gpu_devices = None
        context_launch = isinstance(ctx, DeviceGroup)
        self.context_launch = context_launch
        if context_launch:
            # with context usage: derive the comm mode from which launch
            # mechanisms (MPI / PS) the traversed graph actually needs
            launchMPI, launchPS, self.node_strategy, devices = get_launch_config_by_traverse_nodes(
                eval_node_list, ctx)
            local_gpu_devices = sorted(
                [dev.device_id for dev in devices if dev.local and ndarray.is_gpu_ctx(dev)])
            if not launchMPI and not launchPS:
                self.comm_mode = None
            elif launchMPI and not launchPS:
                self.comm_mode = 'AllReduce'
            elif not launchMPI and launchPS:
                self.comm_mode = 'PS'
            else:
                self.comm_mode = 'Hybrid'
            # in pipeline or model parallel we have to initialize another p2p stream
            init_p2p_stream = len(devices) != len(ctx)
        # variables initialization
        self.seed = seed if seed else np.int64(time())
        self.np_rand = np.random.RandomState(self.seed)
        # get attribute of communication mode
        self.ps_comm = None
        self.nccl_comm = None
        self.local_rank = None
        self.rank = None
        self.nrank = None
        ps_nrank = None
        if self.comm_mode == 'PS' or self.comm_mode == 'Hybrid':
            worker_init()
            self.ps_comm = get_worker_communicate()
            ps_rank = int(self.ps_comm.rank())
            ps_nrank = int(
                os.environ['DMLC_NUM_WORKER']) if 'DMLC_NUM_WORKER' in os.environ else 1
        if self.comm_mode == "Hybrid" or self.comm_mode == "AllReduce":
            self.nccl_comm = wrapped_mpi_nccl_init(devices=local_gpu_devices)
        elif context_launch:
            # model/pipeline parallel without allreduce: MPI is still needed
            # for rank info; NCCL only when a p2p stream is required
            self.nccl_comm = wrapped_mpi_nccl_init(
                init_nccl=init_p2p_stream, devices=local_gpu_devices)
        if self.nccl_comm is not None:
            self.local_rank = self.nccl_comm.local_rank
            device_id = self.nccl_comm.dev_id
            self.rank = self.nccl_comm.rank
            self.nrank = self.nccl_comm.nrank
            if ps_nrank:
                # PS worker count must agree with the NCCL world size in Hybrid mode
                assert ps_nrank == self.nrank
        elif self.comm_mode == 'PS':
            self.rank = ps_rank
            self.nrank = ps_nrank
            if context_launch:
                global mpi_comm
                self.local_rank = mpi_comm.local_rank
                device_id = mpi_comm.dev_id
        # NOTE(review): in a context launch with comm_mode None and no p2p
        # stream, neither branch above defines device_id, which would raise
        # NameError at ndarray.gpu(device_id) below — presumably that
        # configuration never occurs; confirm.
        self.my_eval_nodes = eval_node_list
        self.p2p_stream = None
        self.param_allreduce_group = {}
        if context_launch:
            # comm_mode is None <=> only 1 model parallel instance
            self.context = ndarray.gpu(device_id)
            self.p2p_stream = create_stream_handle(
                self.context) if init_p2p_stream else None
            self.my_eval_nodes, trainable_params, has_send_recv = assign_context_by_traverse_nodes(
                eval_node_list, self.context, self.nccl_comm, self.p2p_stream)
            if (self.comm_mode == "Hybrid" or self.comm_mode == "AllReduce") and has_send_recv:
                # here we need to use group communicator to implement allreduce,
                # since not all processes use the same group
                groups = set([n.raw_ctx for n in trainable_params])
                temp_group_comms = {}
                for group in groups:
                    temp_group_comms[group] = new_group_comm(group)
                self.param_allreduce_group = {
                    n: temp_group_comms[n.raw_ctx] for n in trainable_params}
        else:
            self.context = ctx
        on_gpu = ndarray.is_gpu_ctx(self.context)
        self.nccl_stream = None
        if self.comm_mode == "Hybrid" or self.comm_mode == "AllReduce":
            if on_gpu:
                self.nccl_stream = create_stream_handle(self.context)
            self.nccl_comm = get_nccl_communicate()
        # define streams (compute, host-to-device, device-to-host); CPU-only
        # execution uses no streams
        self.comp_stream = create_stream_handle(
            self.context) if on_gpu else None
        self.h2d_stream = create_stream_handle(
            self.context) if on_gpu else None
        self.d2h_stream = create_stream_handle(
            self.context) if on_gpu else None
        # sparse pull / cache-table policy / prefetch only make sense with a PS
        self.use_sparse_pull = use_sparse_pull if self.comm_mode == 'PS' or self.comm_mode == "Hybrid" else False
        self.cstable_policy = cstable_policy if self.comm_mode == 'PS' or self.comm_mode == "Hybrid" else None
        self.prefetch = prefetch if self.comm_mode == 'PS' or self.comm_mode == 'Hybrid' else False
        if self.cstable_policy is not None:
            # a cache-table policy supersedes plain sparse pull
            self.cstable_policy = self.cstable_policy.upper()
            self.use_sparse_pull = False
        self.h2d_ops = {}
        self.d2h_ops = {}
        self.ps_map = {}
        self.infer_ps_map = {}
        self.enable_lazy = False and enable_lazy  # now we don't use lazy
        self.bsp = bsp
        self.cache_bound = int(cache_bound)
        self.log_path = log_path
        if log_path is not None and (self.comm_mode == 'PS' or self.comm_mode == "Hybrid"):
            assert os.path.isdir(
                log_path), 'Need to specify a work directory to save logs.'
            self.ps_comm.startRecord(ctypes.c_char_p(bytes(log_path, 'utf-8')))
        self.placeholder_to_arr_map = dict()
        # run backward/forward hooks over the (possibly reassigned) eval graph;
        # this also moves parameter tensors into placeholder_to_arr_map
        topo_sort_with_hook(self.my_eval_nodes, self)
class Executor(object):
    """Executor computes values for given set of nodes in computation graph.

    Wraps one SubExecutor per named node list (e.g. 'train', 'validate') and
    offers parameter save/load that works for single-device, AllReduce and
    parameter-server modes.
    """

    def __init__(self, eval_node_dict, config=None, **kargs):
        """
        Parameters
        ----------
        eval_node_dict: dict of list of nodes whose values need to be computed.
            A plain list is accepted and wrapped as {'default': list}.
        config: optional pre-built HetuConfig; when None one is constructed
            from the union of all eval nodes and **kargs.
        """
        if not isinstance(eval_node_dict, dict):
            eval_node_dict = {'default': eval_node_dict}
        train_name, val_name = None, None
        for k, v in eval_node_dict.items():
            if any([isinstance(node, OptimizerOp) for node in v]):
                # get the last subexecutor containing optimizer as train for ps op
                train_name = k
            else:
                # get the last subexecutor not containing optimizer as val for ps op
                val_name = k
        # union of all eval nodes across subexecutors (deduplicated)
        all_eval_nodes = list(set(reduce(add, eval_node_dict.values())))
        if config is None:
            config = HetuConfig(eval_node_list=all_eval_nodes,
                                train_name=train_name, val_name=val_name, **kargs)
        assert isinstance(
            config, HetuConfig), 'Config type %s invalid.' % str(type(config))
        self.eval_node_dict = eval_node_dict
        self.config = config
        # one SubExecutor per named node list, all sharing the same config
        self.subexecutor = {k: SubExecutor(
            k, v, config) for k, v in eval_node_dict.items()}
        self.topo_order = find_topo_sort(config.my_eval_nodes)
        self.param_nodes = [node for node in self.topo_order if isinstance(
            node, PlaceholderOp) and node.trainable]
        self.comm_mode = self.config.comm_mode
        self.ps_comm = self.config.ps_comm
        self.local_rank = self.config.local_rank
        self.rank = self.config.rank

    def run(self, name='default', eval_node_list={}, feed_dict={}, convert_to_numpy_ret_vals=False):
        # delegate to the named subexecutor
        return self.subexecutor[name].run(eval_node_list, feed_dict, convert_to_numpy_ret_vals)

    @property
    def batch_num(self):
        # only meaningful when there is a single subexecutor
        assert len(
            self.subexecutor) == 1, 'Batch num should be used with only 1 subexecutor.'
        return list(self.subexecutor.values())[0].batch_num

    def get_batch_num(self, name='default'):
        """Return the dataloader batch count of the named subexecutor."""
        return self.subexecutor[name].batch_num

    def save(self, file_path):
        """Save trainable parameters into directory `file_path`.

        Local/AllReduce parameters are written as <node.name>.npy by this
        process; PS-held parameters (embeddings, or all params in PS mode)
        are saved by the server via SaveParam, driven by rank 0 only.
        """
        assert os.path.isdir(
            file_path), 'Need to specify a work directory to save parameters.'
        if self.comm_mode in (None, 'AllReduce'):
            # when using allreduce, users need to specify the worker whose rank equals 0 to save
            for node in self.param_nodes:
                np.save(os.path.join(file_path, node.name + '.npy'),
                        self.config.placeholder_to_arr_map[node].asnumpy())
        else:
            self.ps_comm.BarrierWorker()
            if self.config.rank == 0:
                for node in self.param_nodes:
                    if node.is_embed or self.comm_mode == 'PS':
                        # parameter lives on the PS: let the server persist it
                        node.event.sync()
                        nodeid = ctypes.c_int(node.id)
                        self.ps_comm.SaveParam(
                            nodeid, ctypes.c_char_p(bytes(file_path, 'utf-8')))
                        self.ps_comm.Wait(nodeid)
                    else:
                        np.save(os.path.join(file_path, node.name + '.npy'),
                                self.config.placeholder_to_arr_map[node].asnumpy())
            self.ps_comm.BarrierWorker()

    def load(self, file_path):
        """Load trainable parameters from directory `file_path`.

        Mirrors save(): rank 0 asks the server to LoadParam for PS-held
        parameters, then every worker pulls dense parameters (PS mode) or
        reads its own .npy files, and prefetched embeddings are re-pulled.
        """
        assert os.path.isdir(
            file_path), 'Need to specify a work directory to load parameters.'
        if self.comm_mode in (None, 'AllReduce'):
            for node in self.param_nodes:
                self.config.placeholder_to_arr_map[node][:] = np.load(
                    os.path.join(file_path, node.name + '.npy'))
        else:
            self.ps_comm.BarrierWorker()
            if self.config.rank == 0:
                for node in self.param_nodes:
                    if node.is_embed or self.comm_mode == 'PS':
                        node.event.sync()
                        nodeid = ctypes.c_int(node.id)
                        self.ps_comm.LoadParam(
                            nodeid, ctypes.c_char_p(bytes(file_path, 'utf-8')))
                        node.event.update()
            self.ps_comm.BarrierWorker()
            for node in self.topo_order:
                if isinstance(node, PlaceholderOp) and node.trainable and not node.is_embed:
                    if self.comm_mode == 'PS':
                        # refresh the local copy from the server
                        node.event.sync()
                        nodeid = ctypes.c_int(node.id)
                        self.ps_comm.Pull(
                            nodeid, self.config.ps_map[node].handle)
                        node.event.update()
                    else:
                        self.config.placeholder_to_arr_map[node][:] = np.load(
                            os.path.join(file_path, node.name + '.npy'))
                elif isinstance(node, EmbeddingLookUp) and self.config.prefetch:
                    # NOTE(review): Executor defines no self.name attribute
                    # (SubExecutor does) — this branch looks like it would
                    # raise AttributeError if reached; confirm intent.
                    node.event.sync()
                    nodeid = ctypes.c_int(node.inputs[0].id)
                    self.ps_comm.SparsePull(nodeid, node.inputs[1].get_next_arr(
                        self.name).handle, self.config.ps_map[node.inputs[0]].handle)
                    node.event.update()
            self.ps_comm.BarrierWorker()

    def recordLoads(self):
        """Flush pending PS events and ask the server to report load statistics."""
        for node in self.config.ps_map:
            node.event.sync()
        self.ps_comm.getLoads()

    def __del__(self):
        # drain all streams and outstanding events before tearing down,
        # then finalize the PS worker when one was started
        if self.config.comp_stream is not None:
            self.config.comp_stream.sync()
        if self.config.h2d_stream is not None:
            self.config.h2d_stream.sync()
        if self.config.d2h_stream is not None:
            self.config.d2h_stream.sync()
        if self.config.nccl_stream is not None:
            self.config.nccl_stream.sync()
        for node in self.param_nodes:
            if node.event:
                node.event.sync()
        if self.comm_mode in ('PS', 'Hybrid'):
            worker_finish()
class SubExecutor(object):
    """Executes one named eval-node list: shape inference, memory planning and
    the per-iteration compute loop, dispatching each op onto the right stream.
    """

    def __init__(self, name, eval_node_list, config):
        """
        Parameters
        ----------
        name: key of this subexecutor (used to address dataloaders).
        eval_node_list: list of nodes whose values need to be computed.
        config: shared HetuConfig.

        Attributes
        ----------
        topo_order: list of nodes in topological order
        node_to_shape_map: dict from node to shape of the node
        node_to_arr_map: dict from node to ndarray.NDArray allocated for node
        """
        self.name = name
        self.eval_node_list = eval_node_list
        self.config = config
        # inference <=> no optimizer op among the requested nodes
        inference = not any([isinstance(node, OptimizerOp)
                             for node in eval_node_list])
        self.inference = inference
        if config.p2p_stream:
            # model/pipeline parallel: compute config.my_eval_nodes locally and
            # remember where each maps back into the caller's list (-1 = absent)
            self.run_results_indices = [eval_node_list.index(
                node) if node in eval_node_list else -1 for node in config.my_eval_nodes]
            self.eval_node_list = config.my_eval_nodes
            self.global_eval_nodes = eval_node_list
        if inference == False:
            self.topo_order = find_topo_sort(self.eval_node_list)
        else:  # in inference phase
            if self.config.use_sparse_pull == True or self.config.cstable_policy is not None:
                # insert ps_sparse_pull_op
                self.topo_order = find_topo_sort_inference(self.eval_node_list)
                # fetch sparse parameter
                fetch_sparse_parameter_value(self.topo_order, self.config)
            else:
                self.topo_order = find_topo_sort(self.eval_node_list)
        # main structures, nodes' shapes and arrays
        self.node_to_shape_map = {}
        self.node_to_arr_map = {}
        # inherit from configurations
        self.comm_mode = self.config.comm_mode
        self.ps_comm = self.config.ps_comm
        self.nccl_comm = self.config.nccl_comm
        self.comp_stream = self.config.comp_stream
        self.h2d_stream = self.config.h2d_stream
        self.d2h_stream = self.config.d2h_stream
        self.nccl_stream = self.config.nccl_stream
        # inference uses a separate PS value map than training
        self.param_psval_map = self.config.infer_ps_map if self.inference else self.config.ps_map
        self.use_sparse_pull = self.config.use_sparse_pull
        self.cstable_policy = self.config.cstable_policy
        self.use_p2p = self.config.p2p_stream is not None
        # assisting structures, improve performance: classify nodes once
        self.need_feed_nodes = []
        self.param_nodes = []
        self.dataloader_nodes = []
        self.computing_nodes = []
        for node in self.topo_order:
            if isinstance(node, DataloaderOp) or isinstance(node, GNNDataLoaderOp):
                self.dataloader_nodes.append(node)
            elif isinstance(node, PlaceholderOp):
                if node.shape is None:
                    # shapeless placeholders must be fed by the user
                    self.need_feed_nodes.append(node)
                elif node.trainable:
                    self.param_nodes.append(node)
            elif not ((self.use_sparse_pull or self.cstable_policy) and isinstance(node, EmbeddingLookUp) and self.config.prefetch):
                # prefetched embedding lookups are served by the PS, not computed
                self.computing_nodes.append(node)
        self.batch_num = set([node.get_batch_num(self.name)
                              for node in self.dataloader_nodes])
        assert len(self.batch_num) <= 1, 'Batch num not conform.'
        self.batch_num = None if len(
            self.batch_num) == 0 else self.batch_num.pop()
        # with neither feeds nor dataloaders, allocation must happen on first run
        self.init_need_allocation = (self.need_feed_nodes == []) and (
            self.dataloader_nodes == [])

    def update_executor(self, eval_node_list):
        """Re-target this subexecutor at a new eval node list, rebuilding the
        topo order and all cached classification structures."""
        self.eval_node_list = eval_node_list
        inference = not any([isinstance(node, OptimizerOp)
                             for node in eval_node_list])
        self.inference = inference
        if self.config.p2p_stream and self.inference == True:
            raise NotImplementedError
        if inference == False:
            self.topo_order = find_topo_sort(self.eval_node_list)
        else:  # in inference phase
            if self.config.use_sparse_pull == True or self.config.cstable_policy is not None:
                # insert ps_sparse_pull_op
                self.topo_order = find_topo_sort_inference(self.eval_node_list)
                # fetch sparse parameter
                fetch_sparse_parameter_value(self.topo_order, self.config)
            else:
                self.topo_order = find_topo_sort(self.eval_node_list)
        # main structures, nodes' shapes and arrays
        self.node_to_shape_map = {}
        self.node_to_arr_map = {}
        # assisting structures, improve performance
        self.need_feed_nodes = []
        self.param_nodes = []
        self.dataloader_nodes = []
        self.computing_nodes = []
        for node in self.topo_order:
            if isinstance(node, DataloaderOp) or isinstance(node, GNNDataLoaderOp):
                self.dataloader_nodes.append(node)
            elif isinstance(node, PlaceholderOp):
                if node.shape is None:
                    self.need_feed_nodes.append(node)
                elif node.trainable:
                    self.param_nodes.append(node)
            elif not ((self.use_sparse_pull or self.cstable_policy) and isinstance(node, EmbeddingLookUp) and self.config.prefetch):
                self.computing_nodes.append(node)
        self.batch_num = set([node.get_batch_num(self.name)
                              for node in self.dataloader_nodes])
        assert len(self.batch_num) <= 1, 'Batch num not conform.'
        self.batch_num = None if len(
            self.batch_num) == 0 else self.batch_num.pop()
        self.init_need_allocation = (self.need_feed_nodes == []) and (
            self.dataloader_nodes == [])

    def infer_shape(self, feed_shapes):
        """Given shapes of feed_dict nodes, infer shape for all nodes in graph.

        Implementation note:
        Iteratively calls node.infer_shape to infer shapes.
        Node shapes stored in self.node_to_shape_map.

        Parameters
        ----------
        feed_shapes: node->shapes mapping for feed_dict nodes.
        """
        self.node_to_shape_map = {}
        for node in self.topo_order:
            if node in feed_shapes:
                self.node_to_shape_map[node] = tuple(feed_shapes[node])
            else:
                # inputs appear earlier in topo order, so their shapes are known
                input_shapes = [self.node_to_shape_map[n] for n in node.inputs]
                cur_shape = node.infer_shape(input_shapes)
                # some ops (e.g. optimizer) legitimately have no output shape
                self.node_to_shape_map[node] = cur_shape if cur_shape is None else tuple(
                    cur_shape)

    def memory_plan(self):
        """Allocates ndarray.NDArray for every node except feed_dict nodes,
        based on the shapes computed by infer_shape."""
        for node, shape in self.node_to_shape_map.items():
            if isinstance(node, PlaceholderOp):
                if self.config.placeholder_to_arr_map[node] is not None:
                    # parameters already own their storage
                    self.node_to_arr_map[node] = self.config.placeholder_to_arr_map[node]
                elif node not in self.node_to_arr_map:
                    # fed placeholders get their array from feed_dict later
                    self.node_to_arr_map[node] = None
            elif not isinstance(node, DataloaderOp) and not isinstance(node, GNNDataLoaderOp):
                # add for OptimizerOp and ParameterServerOp
                if shape is None:
                    self.node_to_arr_map[node] = None
                    continue
                if isinstance(node, (EmbeddingLookUp_Gradient, DataD2HSparseOp)):
                    # sparse gradients use IndexedSlices, not dense arrays
                    self.node_to_arr_map[node] = ndarray.IndexedSlices(
                        dense_shape=shape)
                    continue
                if isinstance(node, EmbeddingLookUp) and (self.use_sparse_pull or self.cstable_policy) and self.config.prefetch:
                    # prefetched embeddings read directly from the PS value map
                    self.node_to_arr_map[node] = self.param_psval_map[node.inputs[0]]
                    continue
                if node.on_gpu:
                    if node.inplace:
                        # inplace ops wrap their input's storage at compute time
                        self.node_to_arr_map[node] = ndarray.NDArray(None)
                    elif self.inference and isinstance(node, DropoutOp):
                        # dropout is identity at inference: alias the input array
                        self.node_to_arr_map[node] = self.node_to_arr_map[node.inputs[0]]
                    else:
                        self.node_to_arr_map[node] = ndarray.empty(
                            shape, ctx=node.ctx)
                else:
                    self.node_to_arr_map[node] = ndarray.empty(
                        shape, ctx=node.ctx)

    def run(self, eval_node_list={}, feed_dict={}, convert_to_numpy_ret_vals=False):
        """
        Parameters
        ----------
        eval_node_list: optional new node list; triggers update_executor when
            it differs from the current one.
        feed_dict: a dictionary of node->np.ndarray supplied by user.
        convert_to_numpy_ret_vals: whether to convert ret vals to np.array

        Returns
        -------
        A list of values for nodes in eval_node_list. NDArray or np.ndarray.
        """
        assert len(feed_dict) == len(
            self.need_feed_nodes) or self.use_p2p, 'Feed dict invalid.'
        if eval_node_list != {} and eval_node_list != self.eval_node_list:
            self.update_executor(eval_node_list)
        feed_shapes = {}
        need_reallocation = self.init_need_allocation
        # get feed in values
        for node, value in feed_dict.items():
            if self.use_p2p and node not in self.need_feed_nodes:
                # in p2p mode some feeds belong to other pipeline stages
                continue
            assert node in self.need_feed_nodes, 'Only allow feed in PlaceholderOp with no values, here got %s:%s.' % (
                str(type(node)), node.name)
            local_shape = tuple(value.shape)
            # a shape change forces re-planning of the whole graph
            local_realloc = local_shape != self.node_to_shape_map.get(
                node, None)
            need_reallocation = need_reallocation or local_realloc
            if node.on_cpu:
                assert isinstance(value, (np.ndarray, spmatrix, ndarray.NDArray)), \
                    "feed_dict value type not supported"
                if isinstance(value, np.ndarray):
                    if local_realloc:
                        self.node_to_arr_map[node] = ndarray.empty(
                            local_shape, ctx=node.ctx)
                    self.node_to_arr_map[node][:] = value
                else:
                    self.node_to_arr_map[node] = value
            else:
                if isinstance(value, np.ndarray):
                    if local_realloc:
                        self.node_to_arr_map[node] = ndarray.array(
                            value, ctx=node.ctx)
                    else:
                        self.node_to_arr_map[node][:] = value
                elif isinstance(value, spmatrix):
                    # convert any scipy sparse matrix to COO, then to device sparse array
                    value = coo_matrix(value)
                    value = ndarray.sparse_array(value.data,
                                                 (value.row, value.col), shape=local_shape, ctx=node.ctx)
                    self.node_to_arr_map[node] = value
                elif isinstance(value, ndarray.NDArray):
                    if value.ctx == node.ctx:
                        self.node_to_arr_map[node] = value
                    else:
                        if local_realloc:
                            self.node_to_arr_map[node] = ndarray.empty(
                                local_shape, ctx=node.ctx)
                        else:
                            self.node_to_arr_map[node][:] = value
                elif isinstance(value, ndarray.ND_Sparse_Array):
                    self.node_to_arr_map[node] = value
                else:
                    assert False, "feed_dict value type not supported"
            feed_shapes[node] = local_shape
        # get dataloader values
        for node in self.dataloader_nodes:
            local_shape = node.get_cur_shape(self.name)
            local_realloc = local_shape != self.node_to_shape_map.get(
                node, None)
            need_reallocation = need_reallocation or local_realloc
            self.node_to_arr_map[node] = node.get_arr(self.name)
            feed_shapes[node] = local_shape
        # reallocation, infer shapes and allocate memory
        if need_reallocation:
            self.init_need_allocation = False
            self.infer_shape(feed_shapes)
            self.memory_plan()
        # computing
        for node in self.computing_nodes:
            if node.on_cpu and isinstance(self.node_to_arr_map[node], ndarray.NDArray):
                if DNNL_LIB['cpu_ArraySet'] and not isinstance(node, DataD2HOp):
                    # zero the CPU output buffer before compute
                    cpu_array_set(self.node_to_arr_map[node], 0.0)
                else:
                    # here we suppose not using DNNL_LIB
                    # self.node_to_arr_map[node][:] = np.zeros(self.node_to_shape_map[node]).astype(np.float32)
                    pass
            input_vals = [self.node_to_arr_map[n] for n in node.inputs]
            node_val = self.node_to_arr_map[node]
            # wait for producers running on other streams
            for n in node.inputs:
                if n.event:
                    n.event.sync()
            # dispatch each op kind onto its dedicated stream
            if isinstance(node, (ParameterServerCommunicateOp, ParameterServerSparsePullOp)):
                # Here we use d2h stream in ps op, since the stream is used for d2h data transfer.
                # Please take care at this part.
                node.compute(input_vals, node_val, self.d2h_stream)
            elif isinstance(node, AllReduceCommunicateOp):
                node.compute(input_vals, node_val, self.nccl_stream)
            elif isinstance(node, DataH2DOp):
                node.compute(input_vals, node_val, self.h2d_stream)
            elif isinstance(node, (DataD2HOp, DataD2HSparseOp)):
                node.compute(input_vals, node_val, self.d2h_stream)
            elif isinstance(node, (PipelineSendOp, PipelineReceiveOp)):
                node.compute(input_vals, node_val)
            elif isinstance(node, (DropoutOp, Batch_NormalizationOp, Layer_NormalizationOp)):
                # these ops behave differently at inference time
                node.compute(input_vals, node_val,
                             self.comp_stream, inference=self.inference)
                if isinstance(node.event, Event):
                    # for d2h op / eval nodes / nodes before [allreduce or ps nodes or pipelinesend nodes]
                    node.event.record(self.comp_stream)
            else:
                node.compute(input_vals, node_val, self.comp_stream)
                if isinstance(node.event, Event):
                    # for d2h op / eval nodes / nodes before [allreduce or ps nodes or pipelinesend nodes]
                    node.event.record(self.comp_stream)
        for n in self.eval_node_list:
            # every node in eval_node_list should have an event (except dataloader/optimizer...)
            if n.event:
                n.event.sync()
        # get results
        results = [self.node_to_arr_map[n] for n in self.eval_node_list]
        if convert_to_numpy_ret_vals:
            for i in range(len(results)):
                if results[i] is not None:
                    results[i] = results[i].asnumpy()
        # remap to original order in model parallel
        if self.use_p2p:
            new_results = [None for _ in self.global_eval_nodes]
            # NOTE(review): j == -1 (local node absent from the global list)
            # writes into the LAST slot here — presumably such entries are
            # always overwritten or never produced; confirm.
            for i, j in enumerate(self.run_results_indices):
                new_results[j] = results[i]
            results = new_results
        return results
def gradients(output_node, node_list, insert_grad=None):
    """Take gradient of output node with respect to each node in node_list.

    Parameters
    ----------
    output_node: output node (or list of output nodes) that we are taking
        derivative of.
    node_list: list of nodes that we are taking derivative wrt.
    insert_grad: used to assign gradient to output_node in model parallel;
        when given, it seeds the backward pass instead of a ones-like node
        (a list of seeds when output_node is a list).

    Returns
    -------
    A list of gradient values, one for each node in node_list respectively.
    """
    # node_to_output_grads_list maps each forward node to the list of partial
    # adjoints contributed by its consumers; it is seeded at the outputs with
    # either a ones-like node or the caller-supplied insert_grad.
    if isinstance(output_node, list):
        node_to_output_grads_list = {
            output_node[i]: [OnesLike.oneslike_op(output_node[i])] if insert_grad is None
            else [insert_grad[i]] for i in range(len(output_node))
        }
    else:
        node_to_output_grads_list = {
            output_node: [OnesLike.oneslike_op(output_node)] if insert_grad is None else [
                insert_grad]
        }
        # normalize to a list so find_topo_sort below sees a uniform input
        output_node = [output_node]
    # node_to_output_grad maps each forward node to its summed (total) adjoint
    node_to_output_grad = {}
    # Traverse forward graph in reverse topological order
    reverse_topo_order = reversed(find_topo_sort(output_node))
    for node in reverse_topo_order:
        # here the ctx for embedding lookup is a workaround:
        # the summed gradient is placed on the ctx of the first partial
        # adjoint rather than on the lookup node's own raw_ctx
        # TODO: when implement PS strategy for context semantics, modify here
        if isinstance(node, EmbeddingLookUp):
            output_grad = sum_node_list(
                node_to_output_grads_list[node], node_to_output_grads_list[node][0].raw_ctx)
        else:
            output_grad = sum_node_list(
                node_to_output_grads_list[node], node.raw_ctx)
        if output_grad is None:
            # no gradient flows through this node (all partial adjoints were
            # None); still register its inputs so later lookups don't KeyError,
            # and propagate "no gradient" by leaving their lists empty
            for n in node.inputs:
                if n not in node_to_output_grads_list:
                    node_to_output_grads_list[n] = []
            continue
        node_to_output_grad[node] = output_grad
        # ask the op for the gradient wrt each of its inputs
        input_grads_list = node.gradient(output_grad)
        for i in range(len(node.inputs)):
            if node.inputs[i] not in node_to_output_grads_list:
                node_to_output_grads_list[node.inputs[i]] = []
            # Calculate partial adjoint for input nodes.
            node_to_output_grads_list[node.inputs[i]].append(
                input_grads_list[i])
    # NOTE(review): a node in node_list whose total adjoint is None never gets
    # an entry in node_to_output_grad, so this lookup would raise KeyError —
    # presumably callers only request gradients of reachable nodes; confirm.
    grad_node_list = [node_to_output_grad[node] for node in node_list]
    return grad_node_list
  718. ##################
  719. # Helper Methods #
  720. ##################
  721. def topo_sort_with_hook(node_list, config):
  722. visited = set()
  723. for node in node_list:
  724. topo_sort_dfs_with_hook(node, visited, config)
  725. def topo_sort_dfs_with_hook(node, visited, config):
  726. if node in visited:
  727. return
  728. visited.add(node)
  729. node.backward_hook(config)
  730. # move param from node to config
  731. if isinstance(node, PlaceholderOp):
  732. config.placeholder_to_arr_map[node] = node.tensor_value
  733. node.tensor_value = None
  734. for n in node.inputs:
  735. topo_sort_dfs_with_hook(n, visited, config)
  736. node.forward_hook(config)
  737. def find_topo_sort(node_list):
  738. """Given a list of nodes, return a topo ordering of nodes ending in them.
  739. A simple algorithm is to do a post-order DFS traversal on the given nodes,
  740. going backwards based on input edges. Since a node is added to the ordering
  741. after all its predecessors are traversed due to post-order DFS, we get a
  742. topological sort.
  743. """
  744. visited = set()
  745. topo_order = []
  746. for node in node_list:
  747. topo_sort_dfs(node, visited, topo_order)
  748. return topo_order
  749. def topo_sort_dfs(node, visited, topo_order):
  750. """Post-order DFS"""
  751. if node in visited:
  752. return
  753. visited.add(node)
  754. for n in node.inputs:
  755. topo_sort_dfs(n, visited, topo_order)
  756. topo_order.append(node)
  757. def find_topo_sort_inference(node_list):
  758. topo_order = find_topo_sort(node_list)
  759. embedding_list = list()
  760. embedding_outputs = dict()
  761. embedding_cnt = dict()
  762. for node in topo_order:
  763. if isinstance(node, EmbeddingLookUp):
  764. embedding_outputs[node] = list()
  765. embedding_cnt[node] = 0
  766. embedding_list.append(node)
  767. else:
  768. for input_node in node.inputs:
  769. if isinstance(input_node, EmbeddingLookUp):
  770. embedding_outputs[input_node].append(node)
  771. embedding_cnt[input_node] += 1
  772. topo_order_inference = list()
  773. for node in topo_order:
  774. topo_order_inference.append(node)
  775. for embedding in embedding_list:
  776. if node in embedding_outputs[embedding]:
  777. embedding_cnt[embedding] -= 1
  778. if embedding_cnt[embedding] == 0:
  779. topo_order_inference.append(parameterServerSparsePull_op(
  780. embedding, embedding_outputs[embedding]))
  781. embedding_list.remove(embedding)
  782. return topo_order_inference
  783. def fetch_sparse_parameter_value(node_list, config):
  784. for node in node_list:
  785. if isinstance(node, ParameterServerSparsePullOp):
  786. node.forward_hook(config)
  787. def fetch_dense_parameter_value(node_list, config):
  788. assert config.comm_mode in ('PS', 'Hybrid')
  789. topo_order = find_topo_sort(node_list)
  790. val_list = []
  791. # get var list
  792. for node in topo_order:
  793. if isinstance(node, PlaceholderOp) and node.trainable:
  794. val_list.append(node)
  795. for node in val_list:
  796. if config.use_sparse_pull and node.is_embed:
  797. continue
  798. else:
  799. pull_val = ndarray.empty(node.shape, ctx=ndarray.cpu(0))
  800. config.ps_comm.Pull(node.id, pull_val.handle)
  801. config.infer_ps_map[node] = pull_val
  802. config.placeholder_to_arr_map[node] = pull_val
  803. node.event.update()
  804. def sum_node_list(node_list, ctx):
  805. """Custom sum func to avoid creating redundant nodes in Python sum func."""
  806. node_list = [n for n in node_list if n is not None]
  807. if node_list == []:
  808. return None
  809. sum_node = node_list[0]
  810. for n in node_list[1:]:
  811. sum_node = add_op(sum_node, n, ctx=ctx)
  812. return sum_node