You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

ParameterServerCommunicate.py 14 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. from __future__ import absolute_import
  2. from .Node import Op
  3. from .. import ndarray
  4. from ..gpu_links import matrix_elementwise_multiply_by_const
  5. from .. import stream
  6. import os
  7. import numpy as np
  8. import ctypes
class ParameterServerCommunicateOp(Op):
    """Graph op that communicates a parameter's gradient with the parameter
    server (PS) at the end of a backward step.

    The concrete strategy is selected once in ``forward_hook`` by binding
    ``self.compute`` and the ``_mult_lr`` / ``_push`` / ``_pull`` /
    ``_push_pull`` / ``_wait`` / ``_update_event`` slots, covering the
    combinations of: cache table vs. plain PS, sparse (embedding) vs. dense,
    CPU vs. GPU gradients, and BSP vs. ASP with/without prefetch.
    """

    def __init__(self, nodeA, parameter, optimizer):
        # nodeA     : the gradient node whose value is pushed to the PS.
        # parameter : the parameter node this gradient belongs to; its .id
        #             identifies the key on the server.
        # optimizer : optimizer spec; optimizer[1][0] is read as the learning
        #             rate (only fixed-LR SGD is actually applied here).
        super().__init__(ParameterServerCommunicateOp, [nodeA], nodeA.ctx)
        self.on_gpu = ndarray.is_gpu_ctx(self.ctx)
        self.on_cpu = not self.on_gpu
        self.parameter = parameter
        self.optimizer = optimizer
        # the optimizer not implemented yet! only SGD is supported, calculate on worker
        # the optimizer only support fixed learning rate, no scheduler supported.
        # TODO: implement optimizer on Servers(already implemented, not in use) and Caches(not implemented yet)
        # TODO: implement learning rate scheduler
        # Negated so that pushing (grad * learning_rate) is directly the delta
        # added on the server; divided by the worker count to average the
        # contributions of all workers.
        self.learning_rate = -optimizer[1][0] / \
            int(os.environ['DMLC_NUM_WORKER'])
        # PS key for this parameter, as a C int for the ctypes-based comm API.
        self.ps_id = ctypes.c_int(self.parameter.id)
        # CUDA event used to order the async D2H copy before the PS push;
        # created in forward_hook only when a d2h stream is configured.
        self.psevent = None

    def _get_event(self, input_val, stream_handle):
        """Stage a GPU gradient into the CPU-side push buffer.

        With a stream, issue an async device-to-host copy and return the
        event handle the PS layer must wait on; without one, copy
        synchronously and return None.
        """
        if stream_handle:
            self.push_val.async_d2h(input_val, stream_handle, self.psevent)
            evt = self.psevent.handle
        else:
            input_val.copyto(self.push_val)
            evt = None
        return evt

    def _compute_asp_prefetch(self, input_vals, output_val, stream_handle=None):
        # ASP: push this step's gradient and pull (prefetch) the next batch's
        # values in one fused call; no barrier between workers.
        self._mult_lr(input_vals[0], stream_handle)
        self._update_event(self._push_pull(input_vals[0], stream_handle))

    def _compute_bsp_prefetch(self, input_vals, output_val, stream_handle=None):
        # BSP: push, wait for the push to land, synchronize all workers, then
        # pull — so every worker pulls values that include everyone's update.
        self._mult_lr(input_vals[0], stream_handle)
        self._wait(self._push(input_vals[0], stream_handle))
        self.comm.BarrierWorker()
        self._update_event(self._pull())

    def _compute_no_prefetch(self, input_vals, output_val, stream_handle=None):
        # Push only; the pull happens elsewhere (no prefetch configured).
        self._mult_lr(input_vals[0], stream_handle)
        self._update_event(self._push(input_vals[0], stream_handle))

    # --- learning-rate scaling variants (scale the gradient in place) ---

    def _mult_lr_sparse_cpu(self, input_val, stream_handle):
        # input_val is a sparse (indices/values) gradient on CPU.
        input_val.values[:] = input_val.values.asnumpy() * self.learning_rate

    def _mult_lr_dense_cpu(self, input_val, stream_handle):
        input_val[:] = input_val.asnumpy() * self.learning_rate

    def _mult_lr_dense_gpu(self, input_val, stream_handle):
        # In-place scale on the GPU via the element-wise kernel.
        matrix_elementwise_multiply_by_const(
            input_val, self.learning_rate, input_val, stream_handle)

    # --- fused push+pull variants (return a timestamp/handle to wait on) ---

    def _push_pull_cache(self, input_val, stream_handle):
        # Cache table: push this batch's sparse grads, pull next batch's rows.
        return self.cache.embedding_push_pull(
            pullkeys=self.dl_node.get_next_arr(self.dl_name), dest=self.sparse_pull_val,
            pushkeys=input_val.indices, grads=input_val.values
        )

    def _push_pull_sparse_cpu(self, input_val, stream_handle):
        # SS = sparse push + sparse pull (presumably; naming follows the
        # comm wrapper — TODO confirm against the ps-lite binding).
        return self.comm.SSPushPull(self.ps_id, input_val.indices.handle, input_val.values.handle,
                                    self.dl_node.get_next_arr(self.dl_name).handle, self.sparse_pull_val.handle, None)

    def _push_pull_halfsparse_cpu(self, input_val, stream_handle):
        # SD = sparse push + dense pull (embedding gradient, full-table pull).
        return self.comm.SDPushPull(self.ps_id, input_val.indices.handle, input_val.values.handle, self.pull_val.handle, None)

    def _push_pull_dense_cpu(self, input_val, stream_handle):
        # DD = dense push + dense pull.
        return self.comm.DDPushPull(self.ps_id, input_val.handle, self.pull_val.handle, None)

    def _push_pull_dense_gpu(self, input_val, stream_handle):
        # Stage the GPU gradient to the CPU push buffer first; the PS layer
        # waits on the returned event before reading it.
        evt = self._get_event(input_val, stream_handle)
        return self.comm.DDPushPull(self.ps_id, self.push_val.handle, self.pull_val.handle, evt)

    # --- push-only variants ---

    def _push_cache(self, input_val, stream_handle):
        return self.cache.embedding_update(input_val.indices, input_val.values)

    def _push_sparse_cpu(self, input_val, stream_handle):
        return self.comm.SparsePush(self.ps_id, input_val.indices.handle, input_val.values.handle, None)

    def _push_dense_cpu(self, input_val, stream_handle):
        return self.comm.Push(self.ps_id, input_val.handle, None)

    def _push_dense_gpu(self, input_val, stream_handle):
        evt = self._get_event(input_val, stream_handle)
        return self.comm.Push(self.ps_id, self.push_val.handle, evt)

    # --- pull-only variants ---

    def _pull_cache(self):
        return self.cache.embedding_lookup(self.dl_node.get_next_arr(self.dl_name), self.sparse_pull_val)

    def _pull_sparse(self):
        return self.comm.SparsePull(self.ps_id, self.dl_node.get_next_arr(self.dl_name).handle, self.sparse_pull_val.handle)

    def _pull_dense(self):
        return self.comm.Pull(self.ps_id, self.pull_val.handle)

    # --- wait / event-update variants (cache returns a timestamp object;
    #     the plain PS comm is waited on by key) ---

    def _wait_cache(self, ts):
        ts.wait()

    def _wait_ps(self, ts):
        # ts is ignored: the PS comm tracks pending ops per key internally.
        self.comm.Wait(self.ps_id)

    def _update_event_cache(self, ts):
        self.parameter.event.update_ts(ts)

    def _update_event_ps(self, ts):
        self.parameter.event.update()

    def gradient(self, output_grad):
        # Communication op; it terminates the backward graph.
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        # Produces no tensor output.
        return None

    def forward_hook(self, config):
        """Select and bind the communication strategy from ``config``.

        Initializes the parameter on the server (worker 0 only, with a
        barrier), allocates CPU-side push/pull buffers, and wires the
        ``compute`` method to one of the three compute variants.
        """
        # disable inplace if not lazy execution
        # previously we use array reshape lazy callback to do this, which is deprecated (not efficient)
        self.inputs[0].inplace = False
        self.ctx = self.inputs[0].ctx
        self.on_gpu = ndarray.is_gpu_ctx(self.ctx)
        self.on_cpu = not self.on_gpu
        if self.on_gpu and self.inputs[0].event is None:
            self.inputs[0].event = stream.create_event_handle(self.ctx)
        self.comm = config.ps_comm
        node_shape = self.parameter.shape
        # using cache
        if config.cstable_policy is not None and self.parameter.is_embed:
            # Cache-table path: only 2-D embedding tables are supported.
            assert len(node_shape) == 2
            from hetu.cstable import CacheSparseTable
            self._wait = self._wait_cache
            self._update_event = self._update_event_cache
            self._mult_lr = self._mult_lr_sparse_cpu
            if config.bsp and config.prefetch:
                self._push = self._push_cache
                self._pull = self._pull_cache
                self.compute = self._compute_bsp_prefetch
            elif config.prefetch:
                self._push_pull = self._push_pull_cache
                self.compute = self._compute_asp_prefetch
            else:
                self._push = self._push_cache
                self.compute = self._compute_no_prefetch
            # Cache capacity: a tenth of the table's rows.
            limit = node_shape[0] // 10  # TODO: need tuning
            # only worker 0 will do the initialization on server,
            # this function synchronously initialize meta information and do the initialization,
            # ALREADY has barrier!
            self.parameter.initializer.init_on_ps(
                self.comm, self.ps_id, 2, seed=config.seed + self.ps_id.value, opt=self.optimizer)
            self.cache = CacheSparseTable(
                limit, node_shape[0], node_shape[1], self.parameter.id, config.cstable_policy, config.cache_bound)
            self.parameter.cache = self.cache
            if config.prefetch:
                # Warm up: issue the first lookup so the first step's rows are
                # already in flight; the dataloader node is inputs[0].inputs[1].
                self.dl_name = config.train_name
                self.dl_node = self.inputs[0].inputs[1]
                local_shape = list(self.dl_node.get_cur_shape(self.dl_name))
                local_shape.append(node_shape[-1])
                self.sparse_pull_val = ndarray.empty(
                    tuple(local_shape), ctx=ndarray.cpu(0))
                self.parameter.event.update_ts(self.cache.embedding_lookup(
                    self.dl_node.get_next_arr(self.dl_name), self.sparse_pull_val))
                config.ps_map[self.parameter] = self.sparse_pull_val
            return
        # initialize
        self_sparse = self.parameter.is_embed and config.use_sparse_pull
        if self.on_gpu:
            # CPU-side staging buffer for D2H copies of the gradient.
            self.push_val = ndarray.empty(node_shape, ctx=ndarray.cpu(0))
            if config.d2h_stream:
                self.psevent = stream.create_event_handle(self.ctx)
        # only worker 0 will do the initialization on server,
        # this function synchronously initialize meta information and do the initialization,
        # ALREADY has barrier!
        self.parameter.initializer.init_on_ps(self.comm, self.ps_id, int(
            self.parameter.is_embed), seed=config.seed + self.ps_id.value, opt=self.optimizer)
        if self_sparse:
            if config.prefetch:
                # First sparse pull to warm up the prefetch pipeline.
                self.dl_name = config.train_name
                self.dl_node = self.inputs[0].inputs[1]
                local_shape = list(self.dl_node.get_cur_shape(self.dl_name))
                local_shape.append(node_shape[-1])
                self.sparse_pull_val = ndarray.empty(
                    tuple(local_shape), ctx=ndarray.cpu(0))
                self.comm.SparsePull(self.ps_id, self.dl_node.get_next_arr(
                    self.dl_name).handle, self.sparse_pull_val.handle)
                config.ps_map[self.parameter] = self.sparse_pull_val
                self.parameter.event.update()
        else:
            # Dense path: pull the whole parameter once and register the
            # buffer as the parameter's backing array.
            self.pull_val = ndarray.empty(node_shape, ctx=ndarray.cpu(0))
            self.comm.Pull(self.ps_id, self.pull_val.handle)
            config.ps_map[self.parameter] = self.pull_val
            config.placeholder_to_arr_map[self.parameter] = self.pull_val
            self.parameter.event.update()
        # config compute function
        self._wait = self._wait_ps
        self._update_event = self._update_event_ps
        if self_sparse:
            self._mult_lr = self._mult_lr_sparse_cpu
            self._push = self._push_sparse_cpu
            self._pull = self._pull_sparse
            self._push_pull = self._push_pull_sparse_cpu
        elif self.parameter.is_embed:
            # Embedding without sparse pull: sparse push, dense pull.
            self._mult_lr = self._mult_lr_sparse_cpu
            self._push = self._push_sparse_cpu
            self._pull = self._pull_dense
            self._push_pull = self._push_pull_halfsparse_cpu
        elif self.on_cpu:
            self._mult_lr = self._mult_lr_dense_cpu
            self._push = self._push_dense_cpu
            self._pull = self._pull_dense
            self._push_pull = self._push_pull_dense_cpu
        else:
            self._mult_lr = self._mult_lr_dense_gpu
            self._push = self._push_dense_gpu
            self._pull = self._pull_dense
            self._push_pull = self._push_pull_dense_gpu
        if config.bsp and (config.prefetch or not self_sparse):
            self.compute = self._compute_bsp_prefetch
        elif config.prefetch or not self_sparse:
            self.compute = self._compute_asp_prefetch
        else:
            self.compute = self._compute_no_prefetch
# Only the sparse-pull op is inserted into the forward graph; dense pulls are completed at init time.
  199. class ParameterServerSparsePullOp(Op):
  200. def __init__(self, node, deps_node):
  201. super().__init__(ParameterServerSparsePullOp,
  202. [node] + deps_node, node.ctx)
  203. self.on_gpu = ndarray.is_gpu_ctx(self.ctx)
  204. self.on_cpu = not self.on_gpu
  205. self.parameter = node.inputs[0]
  206. self.ps_id = ctypes.c_int(self.parameter.id)
  207. self.psevent = None
  208. def compute(self, input_vals, output_val, stream_handle=None):
  209. comm = self.comm
  210. if self.use_cache_table:
  211. ts = self.cache.embedding_lookup(
  212. self.dl_node.get_next_arr(self.dl_name), self.sparse_pull_val)
  213. self.parameter.event.update_ts(ts)
  214. return
  215. assert self.on_cpu == True
  216. assert isinstance(input_vals[0], ndarray.NDArray)
  217. comm.SparsePull(self.ps_id, self.dl_node.get_next_arr(
  218. self.dl_name).handle, self.sparse_pull_val.handle)
  219. self.parameter.event.update()
  220. def gradient(self, output_grad):
  221. raise NotImplementedError
  222. def infer_shape(self, input_shapes):
  223. return None
  224. def forward_hook(self, config):
  225. self.comm = config.ps_comm
  226. self.use_cache_table = config.cstable_policy is not None
  227. node_shape = self.parameter.shape
  228. assert (
  229. config.use_sparse_pull or self.use_cache_table) and self.parameter.is_embed
  230. self.dl_name = config.val_name
  231. self.dl_node = self.inputs[0].inputs[1]
  232. local_shape = list(self.dl_node.get_cur_shape(self.dl_name))
  233. local_shape.append(node_shape[-1])
  234. self.sparse_pull_val = ndarray.empty(
  235. tuple(local_shape), ctx=ndarray.cpu(0))
  236. config.infer_ps_map[self.parameter] = self.sparse_pull_val
  237. if self.use_cache_table:
  238. self.cache = self.parameter.cache
  239. self.parameter.event.sync()
  240. ts = self.cache.embedding_lookup(
  241. self.dl_node.get_next_arr(self.dl_name), self.sparse_pull_val)
  242. self.parameter.event.update_ts(ts)
  243. else:
  244. self.parameter.event.sync()
  245. self.comm.SparsePull(self.ps_id, self.dl_node.get_next_arr(
  246. self.dl_name).handle, self.sparse_pull_val.handle)
  247. self.parameter.event.update()
def parameterServerCommunicate_op(node, parameter, optimizer):
    """Make a new instance of ParameterServerCommunicateOp and call the instance.

    Parameters:
    ----
    node : Node
        The gradient Node to push to the parameter server.
    parameter : Node
        The parameter Node that corresponds to the gradient.
    optimizer : tuple
        Optimizer specification; optimizer[1][0] is read as the learning
        rate (only fixed-learning-rate SGD is applied by the op).

    Returns:
    ----
    A new Node instance created by Op.
    """
    return ParameterServerCommunicateOp(node, parameter, optimizer)
def parameterServerSparsePull_op(parameter, deps_node):
    """Make a new instance of ParameterServerSparsePullOp and call the instance.

    Parameters:
    ----
    parameter : Node
        The embedding-lookup Node whose parameter (its inputs[0]) is to be
        pulled sparsely from the server.
    deps_node : list of Node
        Dependency Nodes that must execute before the pull.

    Returns:
    ----
    A new Node instance created by Op.
    """
    return ParameterServerSparsePullOp(parameter, deps_node)