You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

DataTransfer.py 4.8 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. from __future__ import absolute_import
  2. from .Node import Op
  3. from .. import ndarray
  4. from .. import stream
  5. from .EmbeddingLookUp import EmbeddingLookUp_Gradient
  6. class DataH2DOp(Op):
  7. # not support sparse matrix!!!
  8. # for sparse matrix, please set Variable's ctx to gpu and pass value in feed_dict
  9. def __init__(self, node_A, ctx):
  10. super().__init__(DataH2DOp, [node_A], ctx)
  11. assert ndarray.is_gpu_ctx(ctx)
  12. assert not ndarray.is_gpu_ctx(node_A.ctx)
  13. self.event = None
  14. self.on_cpu = False
  15. self.on_gpu = True
  16. def compute(self, input_vals, output_val, stream_handle=None):
  17. if stream_handle:
  18. if self.event is None:
  19. self.event = stream.create_event_handle(self.ctx)
  20. output_val.async_h2d(input_vals[0], stream_handle, self.event)
  21. else:
  22. input_vals[0].copyto(output_val)
  23. def gradient(self, output_grad):
  24. if isinstance(output_grad, EmbeddingLookUp_Gradient):
  25. return [datad2h_sparse_op(output_grad)]
  26. else:
  27. return [datad2h_op(output_grad)]
  28. def infer_shape(self, input_shapes):
  29. assert len(input_shapes) == 1
  30. return input_shapes[0]
  31. def forward_hook(self, config):
  32. pass
  33. def backward_hook(self, config):
  34. pass
  35. class DataD2HOp(Op):
  36. def __init__(self, node_A):
  37. assert not isinstance(node_A, EmbeddingLookUp_Gradient)
  38. super().__init__(DataD2HOp, [node_A], ndarray.cpu(0))
  39. assert ndarray.is_gpu_ctx(node_A.ctx)
  40. self.event = None
  41. self.on_cpu = True
  42. self.on_gpu = False
  43. def compute(self, input_vals, output_val, stream_handle=None):
  44. if stream_handle:
  45. if self.event is None:
  46. self.event = stream.create_event_handle(self.inputs[0].ctx)
  47. output_val.async_d2h(input_vals[0], stream_handle, self.event)
  48. else:
  49. input_vals[0].copyto(output_val)
  50. def gradient(self, output_grad):
  51. return [datah2d_op(output_grad, ctx=self.inputs[0].ctx)]
  52. def infer_shape(self, input_shapes):
  53. assert len(input_shapes) == 1
  54. return input_shapes[0]
  55. def forward_hook(self, config):
  56. pass
  57. def backward_hook(self, config):
  58. pass
class DataD2HSparseOp(Op):
    """Graph op that copies an indexed-slices (sparse) value from GPU to CPU.

    Both the indices and the values arrays of the IndexedSlices are
    transferred; output CPU buffers are (re)allocated on demand.
    """
    # here sparse means indexed slices

    def __init__(self, node_A):
        # Only EmbeddingLookUp gradients are accepted as sparse inputs here.
        assert isinstance(node_A, EmbeddingLookUp_Gradient)
        super().__init__(DataD2HSparseOp, [node_A], ndarray.cpu(0))
        # The source node must live on a GPU context.
        assert ndarray.is_gpu_ctx(node_A.ctx)
        self.event = None  # created lazily on first async copy
        self.on_cpu = True
        self.on_gpu = False

    def compute(self, input_vals, output_val, stream_handle=None):
        assert isinstance(input_vals[0], ndarray.IndexedSlices)
        assert isinstance(output_val, ndarray.IndexedSlices)
        # TODO: include all these parts into memory allocation management!!!
        # TODO: also consider how to deduplicate
        # Allocate (or reallocate) the CPU-side buffers whenever the
        # incoming indices shape differs from the cached output buffers.
        # NOTE(review): only the indices shape is compared; this assumes the
        # values shape changes in lockstep with the indices shape — confirm.
        if output_val.indices is None or output_val.indices.shape != input_vals[0].indices.shape:
            output_val.indices = ndarray.empty(
                input_vals[0].indices.shape, ctx=ndarray.cpu(0))
            output_val.values = ndarray.empty(
                input_vals[0].values.shape, ctx=ndarray.cpu(0))
        if stream_handle:
            if self.event is None:
                self.event = stream.create_event_handle(self.inputs[0].ctx)
            # Async D2H copies for both components, recorded on one event.
            output_val.indices.async_d2h(
                input_vals[0].indices, stream_handle, self.event)
            output_val.values.async_d2h(
                input_vals[0].values, stream_handle, self.event)
        else:
            # Blocking copies when no stream is supplied.
            input_vals[0].indices.copyto(output_val.indices)
            input_vals[0].values.copyto(output_val.values)

    def gradient(self, output_grad):
        # No gradient is defined for the sparse transfer op.
        raise NotImplementedError

    def infer_shape(self, input_shapes):
        # A pure copy: the output shape equals the single input's shape.
        assert len(input_shapes) == 1
        return input_shapes[0]

    def forward_hook(self, config):
        pass

    def backward_hook(self, config):
        pass
  97. def datah2d_op(node, ctx):
  98. """Transfer data from host(CPU) to device(GPU).
  99. Parameters:
  100. ----
  101. node : Node
  102. Input variable.
  103. Returns:
  104. ----
  105. A new Node instance created by Op.
  106. """
  107. return DataH2DOp(node, ctx=ctx)
  108. def datad2h_op(node):
  109. """Transfer data from device(GPU) to host(CPU).
  110. Parameters:
  111. ----
  112. node : Node
  113. Input variable.
  114. Returns:
  115. ----
  116. A new Node instance created by Op.
  117. """
  118. return DataD2HOp(node)
  119. def datad2h_sparse_op(node):
  120. """Transfer sparse data from device(GPU) to host(CPU).
  121. Parameters:
  122. ----
  123. node : Node
  124. Input variable.
  125. Returns:
  126. ----
  127. A new Node instance created by Op.
  128. """
  129. return DataD2HSparseOp(node)