You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

dataset_helper.py 12 kB

5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """Dataset help for minddata dataset"""
  16. import math
  17. import os
  18. from mindspore._checkparam import Validator
  19. from .. import context, nn
  20. from ._utils import _exec_datagraph, _get_types_and_shapes, _construct_tensor_list
  21. from ..nn.wrap import GetNextSingleOp
  22. from ..parallel._utils import _get_device_num, _get_global_rank, _need_to_full, _to_full_shapes
  23. from ..ops import operations as P
  24. def _send_data(dataset, epoch_num):
  25. """Engine dataset to write data to tdt queue."""
  26. if not hasattr(dataset, '__has_sent__'):
  27. exec_dataset = dataset.__TRANSFER_DATASET__
  28. exec_dataset.send(epoch_num)
  29. dataset.__has_sent__ = True
  30. def _send_data_no_flag(dataset, epoch_num):
  31. """Engine dataset to write data to tdt queue directly."""
  32. exec_dataset = dataset.__TRANSFER_DATASET__
  33. exec_dataset.send(epoch_num)
  34. def connect_network_with_dataset(network, dataset_helper):
  35. """
  36. Connect the `network` with dataset in `dataset_helper`.
  37. This function wraps the input network with 'GetNext' so that the data can be fetched automatically from the
  38. data channel corresponding to the 'queue_name' and passed to the input network during forward computation.
  39. Note:
  40. In the case of running the network on Ascend in graph mode, this function will wrap the input network with
  41. 'GetNext', in other cases, the input network will be returned with no change.
  42. The 'GetNext' is required to get data only in sink mode, so this function is not applicable to no-sink mode.
  43. Args:
  44. network (Cell): The training network for dataset.
  45. dataset_helper(DatasetHelper): A class to process the MindData dataset, it provides the type, shape and queue
  46. name of the dataset to wrap the `GetNext`.
  47. Outputs:
  48. Cell, a new network wrapped with 'GetNext' in the case of running the task on Ascend in graph mode, otherwise
  49. it is the input network.
  50. Examples:
  51. >>> # call create_dataset function to create a regular dataset, refer to mindspore.dataset
  52. >>> train_dataset = create_dataset()
  53. >>> dataset_helper = mindspore.DatasetHelper(train_dataset, dataset_sink_mode=True)
  54. >>> net = Net()
  55. >>> net_with_get_next = connect_network_with_dataset(net, dataset_helper)
  56. """
  57. class _DataWrapper(nn.Cell):
  58. """
  59. Wraps the input network with a dataset which automatically fetches data with 'GetNext' function from the
  60. dataset channel 'queue_name' and performs the forward computation.
  61. """
  62. def __init__(self, network, dataset_types, dataset_shapes, queue_name):
  63. super(_DataWrapper, self).__init__(auto_prefix=False, flags=network.get_flags())
  64. # Also copy the flag in `network` construct
  65. flags = getattr(network.__class__.construct, "_mindspore_flags", {})
  66. self.add_flags(**flags)
  67. self.get_next = P.GetNext(dataset_types, dataset_shapes, len(dataset_types), queue_name)
  68. self.network = network
  69. def construct(self):
  70. outputs = self.get_next()
  71. return self.network(*outputs)
  72. dataset_iter = dataset_helper.iter
  73. dataset = dataset_iter.dataset
  74. if isinstance(dataset_iter, _DatasetIterNormal):
  75. raise RuntimeError("Dataset should be connected with network only in sink mode.")
  76. if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" and \
  77. not context.get_context("enable_ge"):
  78. dataset.__ME_INITED__ = True
  79. dataset_types, dataset_shapes = dataset_helper.types_shapes()
  80. queue_name = dataset.__TRANSFER_DATASET__.queue_name
  81. network = _DataWrapper(network, dataset_types, dataset_shapes, queue_name)
  82. return network
  83. class DatasetHelper:
  84. """
  85. DatasetHelper is a class to process the MindData dataset and it provides the information of dataset.
  86. According to different contexts, change the iterations of dataset and use the same iteration for loop in different
  87. contexts.
  88. Note:
  89. The iteration of DatasetHelper will provide one epoch data.
  90. Args:
  91. dataset (DataSet): The training dataset iterator.
  92. dataset_sink_mode (bool): If true use GetNext to fetch the data, or else feed the data from host. Default: True.
  93. sink_size (int): Control the amount of data in each sink.
  94. If sink_size=-1, sink the complete dataset for each epoch.
  95. If sink_size>0, sink sink_size data for each epoch. Default: -1.
  96. epoch_num (int): Control the number of epoch data to send. Default: 1.
  97. Examples:
  98. >>> dataset_helper = DatasetHelper(dataset)
  99. >>> for inputs in dataset_helper:
  100. >>> outputs = network(*inputs)
  101. """
  102. def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
  103. dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
  104. Validator.check_is_int(sink_size)
  105. if sink_size < -1 or sink_size == 0:
  106. raise ValueError("The sink_size must be -1 or positive, but got sink_size {}.".format(sink_size))
  107. if dataset_sink_mode:
  108. if context.get_context("enable_ge"):
  109. iterclass = _DatasetIterGE
  110. else:
  111. if context.get_context("device_target") == "Ascend":
  112. iterclass = _DatasetIterMSLoopSink
  113. elif context.get_context("device_target") == "GPU":
  114. ms_role = os.getenv("MS_ROLE")
  115. if ms_role in ("MS_PSERVER", "MS_SCHED"):
  116. iterclass = _DatasetIterPSLite
  117. else:
  118. iterclass = _DatasetIterMS
  119. elif context.get_context("device_target") == "CPU":
  120. raise RuntimeError("Currently dataset sink mode is not supported when the device target is CPU.")
  121. self.iter = iterclass(dataset, sink_size, epoch_num)
  122. else:
  123. iterclass = _DatasetIterNormal
  124. self.iter = iterclass(dataset, epoch_num=epoch_num)
  125. def __iter__(self):
  126. return self.iter.__iter__()
  127. # A temp solution for loop sink. Delete later
  128. def types_shapes(self):
  129. """Get the types and shapes from dataset on the current configuration."""
  130. return self.iter.types_shapes()
  131. def sink_size(self):
  132. """Get sink_size for each iteration."""
  133. return self.iter.get_sink_size()
  134. def stop_send(self):
  135. """Free up resources about data sink."""
  136. self.iter.stop_send()
  137. def continue_send(self):
  138. """continue send data to device at the beginning of epoch."""
  139. self.iter.continue_send()
  140. class _DatasetIter:
  141. """Base iter for dataset helper"""
  142. def __init__(self, dataset, sink_size, epoch_num):
  143. self.dataset = dataset
  144. self.sink_size = sink_size
  145. self.sink_count = 1
  146. if not hasattr(dataset, '__TRANSFER_DATASET__'):
  147. if hasattr(dataset, '__loop_size__'):
  148. self.sink_size = dataset.__loop_size__
  149. dataset.__TRANSFER_DATASET__ = _exec_datagraph(dataset, self.sink_size)
  150. if not hasattr(dataset, '__no_send__'):
  151. _send_data(dataset, epoch_num)
  152. else:
  153. _send_data_no_flag(dataset, epoch_num)
  154. self.stop_send = dataset.__TRANSFER_DATASET__.stop_send
  155. self.continue_send = dataset.__TRANSFER_DATASET__.continue_send
  156. self.dataset_types, self.dataset_shapes = _get_types_and_shapes(dataset)
  157. def __iter__(self):
  158. self.index = 0
  159. return self
  160. def __next__(self):
  161. if self.index >= self.sink_count:
  162. raise StopIteration()
  163. self.index += 1
  164. return self.op()
  165. def types_shapes(self):
  166. return self.dataset_types, self.dataset_shapes
  167. def get_sink_count(self, dataset):
  168. sink_count = 1
  169. if hasattr(dataset, '__loop_size__'):
  170. loop_size = dataset.__loop_size__
  171. if loop_size <= dataset.get_dataset_size() and dataset.get_dataset_size() % loop_size != 0:
  172. raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
  173. f'sink_size {loop_size} are not matched.')
  174. sink_count = math.ceil(dataset.get_dataset_size() / loop_size)
  175. return sink_count
  176. def get_sink_size(self):
  177. """get sink_size to device"""
  178. sink_size = 1
  179. if hasattr(self.dataset, '__loop_size__'):
  180. sink_size = self.dataset.__loop_size__
  181. else:
  182. if context.get_context("enable_ge") or context.get_context("device_target") == "Ascend":
  183. if self.sink_size > 0:
  184. sink_size = self.sink_size
  185. else:
  186. sink_size = self.dataset.get_dataset_size()
  187. return sink_size
  188. class _DatasetIterGE(_DatasetIter):
  189. """Iter for GE."""
  190. def __init__(self, dataset, sink_size, epoch_num):
  191. super().__init__(dataset, sink_size, epoch_num)
  192. self.sink_count = self.get_sink_count(dataset)
  193. batch_expand_num = 1
  194. if _need_to_full():
  195. batch_expand_num = _get_device_num()
  196. tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
  197. def op():
  198. return tensor_list_run
  199. self.op = op
  200. class _DatasetIterMSLoopSink(_DatasetIter):
  201. """Iter for context (device_target=Ascend)"""
  202. def __init__(self, dataset, sink_size, epoch_num):
  203. super().__init__(dataset, sink_size, epoch_num)
  204. self.sink_count = self.get_sink_count(dataset)
  205. ms_role = os.getenv("MS_ROLE")
  206. if ms_role in ("MS_PSERVER", "MS_SCHED"):
  207. self.sink_count = 1
  208. # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, and not using full_batch,
  209. # use a complete tensor to compile, and slice tensor to run. The batch dimension of tensors for
  210. # compile is device_number times the batch dimension of tensors for run. Now only support LoopSink.
  211. if _need_to_full():
  212. device_num = _get_device_num()
  213. self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)
  214. def op():
  215. return tuple()
  216. self.op = op
  217. class _DatasetIterMS(_DatasetIter):
  218. """Iter for MS(enable_loop_sink=False)."""
  219. def __init__(self, dataset, sink_size, epoch_num):
  220. super().__init__(dataset, sink_size, epoch_num)
  221. if sink_size > 0:
  222. self.sink_count = sink_size
  223. else:
  224. self.sink_count = dataset.get_dataset_size()
  225. queue_name = dataset.__TRANSFER_DATASET__.queue_name
  226. self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
  227. class _DatasetIterPSLite(_DatasetIter):
  228. """Iter for context (device_target=GPU) on MS_PSERVER or MS_SCHED"""
  229. def __init__(self, dataset, sink_size, epoch_num):
  230. super().__init__(dataset, sink_size, epoch_num)
  231. self.sink_count = 1
  232. self.sink_size = 1
  233. self.op = None
  234. def op():
  235. return _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num=1)
  236. self.op = op
  237. class _DatasetIterNormal:
  238. """Iter for normal(non sink) mode, feed the data from host."""
  239. def __init__(self, dataset, epoch_num=-1):
  240. self.dataset = dataset
  241. self.device_num = _get_device_num()
  242. self.global_rank = _get_global_rank()
  243. self.iter = self.dataset.create_tuple_iterator(num_epochs=epoch_num)
  244. def __iter__(self):
  245. return self
  246. def __next__(self):
  247. data = self.iter.__next__()
  248. return data
# Explicit public API: only the helper class and the network-connection
# function are meant to be imported from this module.
__all__ = ["DatasetHelper", "connect_network_with_dataset"]