You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

dataloader.py 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import collections
  10. import math
  11. import multiprocessing
  12. import platform
  13. import queue
  14. import random
  15. import threading
  16. import time
  17. import numpy as np
  18. from ..logger import get_logger
  19. from ..random.rng import _random_seed_generator
  20. from .collator import Collator
  21. from .dataset import Dataset, StreamDataset
  22. from .sampler import MapSampler, Sampler, SequentialSampler, StreamSampler
  23. from .transform import PseudoTransform, Transform
  24. try:
  25. import thread
  26. except:
  27. import _thread as thread
logger = get_logger(__name__)
# Default timeout (seconds) used for blocking queue operations and liveness
# polling in the worker processes below.
GLOBAL_TIMEOUT = 5
  30. class DataLoader:
  31. __initialized = False
  32. def __init__(
  33. self,
  34. dataset: Dataset,
  35. sampler: Sampler = None,
  36. transform: Transform = None,
  37. collator: Collator = None,
  38. num_workers: int = 0,
  39. timeout: int = GLOBAL_TIMEOUT,
  40. divide: bool = False,
  41. ):
  42. r"""
  43. Provides a convenient way to iterate on a given dataset.
  44. `DataLoader` combines a dataset with `sampler`, `transform` and `collator`,
  45. make it flexible to get minibatch continually from a dataset.
  46. :type dataset: Dataset
  47. :param dataset: dataset from which to load the minibatch.
  48. :type sampler: Sampler
  49. :param sampler: defines the strategy to sample data from the dataset.
  50. :type transform: Transform
  51. :param transform: defined the transforming strategy for a sampled batch.
  52. Default: None
  53. :type collator: Collator
  54. :param collator: defined the merging strategy for a transformed batch.
  55. Default: None
  56. :type num_workers: int
  57. :param num_workers: the number of sub-process to load, transform and collate
  58. the batch. ``0`` means using single-process. Default: 0
  59. :type timeout: int
  60. :param timeout: if positive, means the timeout value(second) for collecting a
  61. batch from workers. Default: 0
  62. :type divide: bool
  63. :param divide: define the paralleling strategy in multi-processing mode.
  64. ``True`` means one batch is divided into :attr:`num_workers` pieces, and
  65. the workers will process these pieces parallelly. ``False`` means
  66. different sub-process will process different batch. Default: False
  67. """
  68. if num_workers < 0:
  69. raise ValueError("num_workers should not be negative")
  70. if timeout < 0:
  71. raise ValueError("timeout should not be negative")
  72. if divide and num_workers <= 1:
  73. raise ValueError("divide should not be set to True when num_workers <= 1")
  74. self.dataset = dataset
  75. self.num_workers = num_workers
  76. self.timeout = timeout
  77. self.divide = divide
  78. if isinstance(dataset, StreamDataset):
  79. self.sampler = sampler if sampler else StreamSampler(batch_size=1)
  80. assert isinstance(
  81. self.sampler, StreamSampler
  82. ), "types of dataset and sampler do not match"
  83. else:
  84. assert isinstance(
  85. dataset, Dataset
  86. ), "Can not recognize this kind of dataset: %s" % type(dataset)
  87. self.sampler = (
  88. sampler
  89. if sampler
  90. else SequentialSampler(dataset, batch_size=1, drop_last=False)
  91. )
  92. assert isinstance(
  93. self.sampler, MapSampler
  94. ), "types of dataset and sampler do not match"
  95. if divide:
  96. if self.sampler.batch_size <= self.num_workers:
  97. raise ValueError(
  98. "batch size must not smaller than num_workers in divide mode."
  99. )
  100. elif self.sampler.batch_size % self.num_workers:
  101. logger.warning(
  102. "batch size is not divisible by num_workers, may lose performance in divide mode."
  103. )
  104. if transform is None:
  105. self.transform = PseudoTransform()
  106. else:
  107. self.transform = transform
  108. if collator is None:
  109. self.collator = Collator()
  110. else:
  111. self.collator = collator
  112. self.__initialized = True
  113. def __iter__(self):
  114. if platform.system() == "Windows" and self.num_workers > 0:
  115. print(
  116. "pyarrow.plasma does not support ParallelDataLoader on windows, changing num_workers to be zero"
  117. )
  118. self.num_workers = 0
  119. if isinstance(self.dataset, StreamDataset):
  120. if not self.num_workers:
  121. return _SerialStreamDataLoaderIter(self)
  122. else:
  123. return _ParallelStreamDataLoaderIter(self)
  124. else:
  125. assert isinstance(
  126. self.dataset, Dataset
  127. ), "Can not recognize this kind of dataset: %s" % type(self.dataset)
  128. if not self.num_workers:
  129. return _SerialMapDataLoaderIter(self)
  130. else:
  131. return _ParallelMapDataLoaderIter(self)
  132. def __len__(self):
  133. return len(self.sampler)
  134. class _BaseMapDataLoaderIter:
  135. def __init__(self, loader):
  136. self.dataset = loader.dataset
  137. self.sampler = loader.sampler
  138. self.seed = _random_seed_generator().__next__()
  139. self.transform = loader.transform
  140. self.collator = loader.collator
  141. self.num_workers = loader.num_workers
  142. self.timeout = loader.timeout
  143. self.divide = loader.divide
  144. self.num_processed = 0
  145. def _get_next_batch(self):
  146. raise NotImplementedError
  147. def __len__(self):
  148. return len(self.sampler)
  149. def __iter__(self):
  150. return self
  151. def __next__(self):
  152. if self.num_processed >= len(self):
  153. raise StopIteration
  154. minibatch = self._get_next_batch()
  155. self.num_processed += 1
  156. return minibatch
  157. class _SerialMapDataLoaderIter(_BaseMapDataLoaderIter):
  158. def __init__(self, loader):
  159. super(_SerialMapDataLoaderIter, self).__init__(loader)
  160. self.indices_iter = iter(self.sampler)
  161. def _get_next_batch(self):
  162. indices = next(self.indices_iter)
  163. items = [self.dataset[idx] for idx in indices]
  164. trans_items = self.transform.apply_batch(items)
  165. return self.collator.apply(trans_items)
  166. class _ParallelMapDataLoaderIter(_BaseMapDataLoaderIter):
  167. __initialized = False
  168. def __init__(self, loader):
  169. super(_ParallelMapDataLoaderIter, self).__init__(loader)
  170. self.task_queues = [
  171. multiprocessing.Queue(maxsize=2) for _ in range(self.num_workers)
  172. ]
  173. self.feed_batch_idx = multiprocessing.Value("i", 0)
  174. self.target_batch_idx = multiprocessing.Value("i", 0)
  175. self.shutdown_flag = multiprocessing.Value("i", 0)
  176. self.trans_data_queues = [
  177. multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
  178. ]
  179. # use shared-memory queue implemented by pyarrow plasma store.
  180. from ._queue import PlasmaShmQueue
  181. self.batch_queue = PlasmaShmQueue(maxsize=2)
  182. self.task_feeding_worker = multiprocessing.Process(
  183. target=_task_feeding_loop,
  184. args=(
  185. iter(self.sampler),
  186. self.task_queues,
  187. self.num_workers,
  188. self.divide,
  189. self.shutdown_flag,
  190. self.feed_batch_idx,
  191. ),
  192. daemon=True,
  193. )
  194. self.task_feeding_worker.start()
  195. self.workers = []
  196. for worker_id in range(self.num_workers):
  197. worker = multiprocessing.Process(
  198. target=_worker_loop,
  199. args=(
  200. self.dataset,
  201. self.task_queues[worker_id],
  202. self.trans_data_queues[worker_id],
  203. self.transform,
  204. self.seed + worker_id + 1,
  205. self.shutdown_flag,
  206. ),
  207. daemon=True,
  208. )
  209. worker.start()
  210. self.workers.append(worker)
  211. if self.divide:
  212. self.data_collecting_worker = multiprocessing.Process(
  213. target=_data_gathering_loop,
  214. args=(
  215. self.trans_data_queues,
  216. self.batch_queue,
  217. self.collator,
  218. len(self),
  219. self.num_workers,
  220. self.shutdown_flag,
  221. self.target_batch_idx,
  222. ),
  223. daemon=True,
  224. )
  225. else:
  226. self.data_collecting_worker = multiprocessing.Process(
  227. target=_data_selecting_loop,
  228. args=(
  229. self.trans_data_queues,
  230. self.batch_queue,
  231. self.collator,
  232. len(self),
  233. self.num_workers,
  234. self.shutdown_flag,
  235. self.target_batch_idx,
  236. ),
  237. daemon=True,
  238. )
  239. self.data_collecting_worker.start()
  240. self.__initialized = True
  241. def _check_workers(self):
  242. # Check the status of each worker.
  243. if not self.data_collecting_worker.is_alive():
  244. exitcode = self.task_feeding_worker.exitcode
  245. if exitcode != 0:
  246. raise RuntimeError("data collecting worker died. {}".format(exitcode))
  247. if not self.task_feeding_worker.is_alive():
  248. exitcode = self.task_feeding_worker.exitcode
  249. if exitcode != 0:
  250. raise RuntimeError("task feeding worker died. {}".format(exitcode))
  251. for worker_id, worker in enumerate(self.workers):
  252. if not worker.is_alive():
  253. exitcode = worker.exitcode
  254. if exitcode != 0:
  255. raise RuntimeError("worker:{} died. {}".format(worker_id, exitcode))
  256. logger.debug("all workers are alive.")
  257. def _try_get_next_batch(self):
  258. start_time = time.time()
  259. while True:
  260. self._check_workers()
  261. try:
  262. return self.batch_queue.get(timeout=1)
  263. except queue.Empty:
  264. logger.debug("batch queue empty!")
  265. waited_time = time.time() - start_time
  266. if self.timeout > 0:
  267. if waited_time > self.timeout:
  268. raise RuntimeError("get_next_batch timeout!")
  269. def _get_next_batch(self):
  270. batch_data = self._try_get_next_batch()
  271. return batch_data
  272. def _shutdown(self):
  273. with self.shutdown_flag.get_lock():
  274. self.shutdown_flag.value = 1
  275. if self.task_feeding_worker.is_alive():
  276. self.task_feeding_worker.terminate()
  277. self.task_feeding_worker.join()
  278. if self.data_collecting_worker.is_alive():
  279. self.data_collecting_worker.terminate()
  280. self.data_collecting_worker.join()
  281. for worker in self.workers:
  282. if worker.is_alive():
  283. worker.terminate()
  284. worker.join()
  285. for q in self.trans_data_queues:
  286. q.cancel_join_thread()
  287. q.close()
  288. for q in self.task_queues:
  289. q.cancel_join_thread()
  290. q.close()
  291. self.batch_queue.cancel_join_thread()
  292. self.batch_queue.close()
  293. def __del__(self):
  294. if self.__initialized:
  295. self._shutdown()
  296. class _BaseStreamDataLoaderIter:
  297. def __init__(self, loader):
  298. self.dataset = loader.dataset
  299. self.sampler = loader.sampler
  300. self.transform = loader.transform
  301. self.collator = loader.collator
  302. self.num_workers = loader.num_workers
  303. self.timeout = loader.timeout
  304. def _get_next_batch(self):
  305. raise NotImplementedError
  306. def __iter__(self):
  307. return self
  308. def __next__(self):
  309. return self._get_next_batch()
  310. class _SerialStreamDataLoaderIter(_BaseStreamDataLoaderIter):
  311. def __init__(self, loader):
  312. super().__init__(loader)
  313. self.dataset_iter = iter(self.dataset)
  314. self.idx = 0
  315. self.data = None
  316. def _get_next_batch(self):
  317. ret = []
  318. while len(ret) != self.sampler.batch_size:
  319. if self.idx != 0:
  320. data = self.data
  321. else:
  322. try:
  323. timer = threading.Timer(self.timeout, thread.interrupt_main)
  324. timer.start()
  325. raw_data = next(self.dataset_iter)
  326. timer.cancel()
  327. except KeyboardInterrupt:
  328. raise RuntimeError("get_next_batch timeout!")
  329. except:
  330. timer.cancel()
  331. continue
  332. assert len(raw_data) == 2 and isinstance(
  333. raw_data[0], bool
  334. ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
  335. if not raw_data[0]:
  336. data = list((x,) for x in raw_data[1])
  337. else:
  338. data = raw_data[1]
  339. for idx in range(self.idx, len(data[0])):
  340. trans_data = self.transform.apply(tuple(e[idx] for e in data))
  341. ret.append(trans_data)
  342. if len(ret) == self.sampler.batch_size:
  343. if idx + 1 == len(data[0]):
  344. self.idx = 0
  345. self.data = None
  346. else:
  347. self.idx = idx
  348. self.data = data
  349. break
  350. return self.collator.apply(ret)
  351. class _ParallelStreamDataLoaderIter(_BaseStreamDataLoaderIter):
  352. __initialized = False
  353. def __init__(self, loader):
  354. super().__init__(loader)
  355. self.shutdown_flag = multiprocessing.Value("i", 0)
  356. self.raw_data_queues = [
  357. multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
  358. ]
  359. self.trans_data_queues = [
  360. multiprocessing.Queue(maxsize=1) for _ in range(self.num_workers)
  361. ]
  362. # shared-memory queue implemented by pyarrow plasma store
  363. from ._queue import PlasmaShmQueue
  364. self.batch_queue = PlasmaShmQueue(maxsize=2)
  365. self.recieve_worker = multiprocessing.Process(target=self._recieve, daemon=True)
  366. self.recieve_worker.start()
  367. self.transform_workers = []
  368. for worker_id in range(self.num_workers):
  369. worker = multiprocessing.Process(
  370. target=self._transform, args=(worker_id,), daemon=True
  371. )
  372. worker.start()
  373. self.transform_workers.append(worker)
  374. self.collect_worker = multiprocessing.Process(target=self._collect, daemon=True)
  375. self.collect_worker.start()
  376. self.__initialized = True
  377. def _recieve(self):
  378. dataset_iter = iter(self.dataset)
  379. cnt = -1
  380. while True:
  381. if self.shutdown_flag.value == 1:
  382. break
  383. raw_data = next(dataset_iter)
  384. assert len(raw_data) == 2 and isinstance(
  385. raw_data[0], bool
  386. ), "StreamDataset should provide a binary tuple, the first item indicates whether the data was batched."
  387. if not raw_data[0]:
  388. data = list((x,) for x in raw_data[1])
  389. else:
  390. data = raw_data[1]
  391. for idx in range(len(data[0])):
  392. while True:
  393. cnt += 1
  394. qid = cnt % self.num_workers
  395. try:
  396. self.raw_data_queues[qid].put(tuple(e[idx] for e in data))
  397. break
  398. except queue.Full:
  399. if self.shutdown_flag.value == 1:
  400. break
  401. logger.debug("raw data queue is full")
  402. def _transform(self, worker_id):
  403. while True:
  404. if self.shutdown_flag.value == 1:
  405. break
  406. try:
  407. data = self.raw_data_queues[worker_id].get(timeout=GLOBAL_TIMEOUT)
  408. except queue.Empty:
  409. continue
  410. trans_data = self.transform.apply(data)
  411. while True:
  412. try:
  413. self.trans_data_queues[worker_id].put(trans_data)
  414. break
  415. except queue.Full:
  416. if self.shutdown_flag.value == 1:
  417. break
  418. logger.debug("batch queue if full")
  419. def _collect(self):
  420. cnt = -1
  421. trans_items = []
  422. while True:
  423. if self.shutdown_flag.value == 1:
  424. break
  425. cnt += 1
  426. queue_id = cnt % self.num_workers
  427. try:
  428. trans_item = self.trans_data_queues[queue_id].get(
  429. timeout=GLOBAL_TIMEOUT
  430. )
  431. except queue.Empty:
  432. continue
  433. trans_items.append(trans_item)
  434. if len(trans_items) == self.sampler.batch_size:
  435. batch_data = self.collator.apply(trans_items)
  436. while True:
  437. try:
  438. self.batch_queue.put(batch_data, timeout=1)
  439. break
  440. except queue.Full:
  441. if self.shutdown_flag.value == 1:
  442. break
  443. logger.debug("batch queue is full")
  444. trans_items = []
  445. def _check_workers(self):
  446. if not self.collect_worker.is_alive():
  447. exitcode = self.collect_worker.exitcode
  448. if exitcode != 0:
  449. raise RuntimeError("collator worker died. {}".format(exitcode))
  450. for worker_id, worker in enumerate(self.transform_workers):
  451. if not worker.is_alive():
  452. exitcode = worker.exitcode
  453. if exitcode != 0:
  454. raise RuntimeError(
  455. "worker: {} died. {}".format(worker_id, exitcode)
  456. )
  457. def _try_get_next_batch(self):
  458. start_time = time.time()
  459. while True:
  460. self._check_workers()
  461. try:
  462. return self.batch_queue.get(timeout=1)
  463. except queue.Empty:
  464. logger.debug("batch queue empty!")
  465. waited_time = time.time() - start_time
  466. if self.timeout > 0 and waited_time > self.timeout:
  467. raise RuntimeError("get_next_batch timeout!")
  468. def _get_next_batch(self):
  469. batch_data = self._try_get_next_batch()
  470. return batch_data
  471. def _shutdown(self):
  472. with self.shutdown_flag.get_lock():
  473. self.shutdown_flag.value = 1
  474. if self.recieve_worker.is_alive():
  475. self.recieve_worker.terminate()
  476. self.recieve_worker.join()
  477. if self.collect_worker.is_alive():
  478. self.collect_worker.terminate()
  479. self.collect_worker.join()
  480. for worker in self.transform_workers:
  481. if worker.is_alive():
  482. worker.terminate()
  483. worker.join()
  484. for q in self.raw_data_queues:
  485. q.cancel_join_thread()
  486. q.close()
  487. for q in self.trans_data_queues:
  488. q.cancel_join_thread()
  489. q.close()
  490. self.batch_queue.cancel_join_thread()
  491. self.batch_queue.close()
  492. def __del__(self):
  493. if self.__initialized:
  494. self._shutdown()
  495. def _task_feeding_loop(
  496. indices_iter, task_queues, num_workers, divide, shutdown_flag, feed_batch_idx
  497. ):
  498. # Feed the indices into the task queues
  499. while True:
  500. if shutdown_flag.value == 1:
  501. break
  502. batch_idx = feed_batch_idx.value
  503. try:
  504. indices = next(indices_iter)
  505. except StopIteration:
  506. break
  507. if divide:
  508. # make sure all task_queues is ready for put
  509. while any([q.full() for q in task_queues]):
  510. if shutdown_flag.value == 1:
  511. return
  512. # divide into small pieces, feed to different workers.
  513. sub_num = math.ceil(len(indices) / num_workers)
  514. for worker_id in range(num_workers):
  515. sub_indices = indices[worker_id * sub_num : (worker_id + 1) * sub_num]
  516. task_queues[worker_id].put((batch_idx, sub_indices))
  517. else:
  518. # distribute tasks to different workers uniformly.
  519. target_id = batch_idx % num_workers
  520. while task_queues[target_id].full():
  521. if shutdown_flag.value == 1:
  522. return
  523. task_queues[target_id].put((batch_idx, indices))
  524. with feed_batch_idx.get_lock():
  525. feed_batch_idx.value += 1
  526. def _worker_loop(dataset, task_queue, trans_data_queue, transform, seed, shutdown_flag):
  527. # Get dataset items and do the transform
  528. random.seed(seed)
  529. np.random.seed(seed)
  530. while True:
  531. if shutdown_flag.value == 1:
  532. break
  533. try:
  534. batch_idx, indices = task_queue.get(timeout=GLOBAL_TIMEOUT)
  535. except queue.Empty:
  536. continue
  537. if len(indices) > 0:
  538. items = [dataset[idx] for idx in indices]
  539. trans_items = transform.apply_batch(items)
  540. else:
  541. # in case of incomplete last batch
  542. trans_items = ()
  543. while True:
  544. try:
  545. trans_data_queue.put((batch_idx, trans_items), timeout=1)
  546. break
  547. except queue.Full:
  548. if shutdown_flag.value == 1:
  549. break
  550. logger.debug("batch part queue is full!")
def _data_gathering_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Gathering the small pieces of batch data into full batch data.
    # Used in divide mode: every worker produced one piece of batch
    # `target_idx`, so take exactly one piece from each worker's queue,
    # concatenate, collate and publish to the shared-memory batch queue.
    while True:
        if shutdown_flag.value == 1:
            break
        # index of the next batch expected by the main process
        target_batch_idx = target_idx.value
        if target_batch_idx >= length:
            # all `length` batches of this epoch have been produced
            break
        full_trans_items = []
        for worker_id in range(num_workers):
            # block until this worker's piece arrives (or shutdown)
            while True:
                try:
                    batch_idx, trans_items = trans_data_queues[worker_id].get(
                        timeout=GLOBAL_TIMEOUT
                    )
                    break
                except queue.Empty:
                    if shutdown_flag.value == 1:
                        break
                    logger.debug(
                        "worker:{} data queue get timeout! target batch idx:{}".format(
                            worker_id, target_batch_idx
                        )
                    )
            # NOTE(review): if the loop above exits via the shutdown branch,
            # `batch_idx` may be unbound/stale here — presumably benign since
            # shutdown tears this process down; confirm before relying on it.
            if batch_idx != target_batch_idx:
                raise RuntimeError(
                    "Unexperted batch_idx in data gathering loop. worker_id:{}.".format(
                        worker_id
                    )
                )
            else:
                full_trans_items.extend(trans_items)
        # Merge different parts into a batch.
        full_batch = collator.apply(full_trans_items)
        # publish the assembled batch, retrying while the queue is full
        while True:
            try:
                batch_queue.put(full_batch, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")
        with target_idx.get_lock():
            target_idx.value += 1
    # release this process's connection to the plasma store
    batch_queue.disconnect_client()
def _data_selecting_loop(
    trans_data_queues,
    batch_queue,
    collator,
    length,
    num_workers,
    shutdown_flag,
    target_idx,
):
    # Make sure that batch is generated exactly with the same order as generated indices.
    # Used in non-divide mode: batch `i` was assigned to worker
    # `i % num_workers` by the feeding loop, so read from exactly that
    # worker's queue, collate and publish to the shared-memory batch queue.
    while True:
        if shutdown_flag.value == 1:
            break
        # index of the next batch expected by the main process
        target_batch_idx = target_idx.value
        if target_batch_idx >= length:
            # all `length` batches of this epoch have been produced
            break
        target_worker_id = target_batch_idx % num_workers
        # block until the expected worker's batch arrives (or shutdown)
        while True:
            try:
                batch_idx, trans_items = trans_data_queues[target_worker_id].get(
                    timeout=GLOBAL_TIMEOUT
                )
                batch_data = collator.apply(trans_items)
                break
            except queue.Empty:
                if shutdown_flag.value == 1:
                    break
                logger.debug(
                    "worker:{} data queue get timeout! target batch idx:{}".format(
                        target_worker_id, target_batch_idx
                    )
                )
        # NOTE(review): if the loop above exits via the shutdown branch,
        # `batch_idx`/`batch_data` may be unbound/stale here — presumably
        # benign since shutdown tears this process down; confirm.
        if batch_idx != target_batch_idx:
            raise RuntimeError(
                "batch_idx {} mismatch the target_batch_idx {}".format(
                    batch_idx, target_batch_idx
                )
            )
        # publish the batch, retrying while the queue is full
        while True:
            try:
                batch_queue.put(batch_data, timeout=1)
                break
            except queue.Full:
                if shutdown_flag.value == 1:
                    break
                logger.debug("batch queue is full!")
        with target_idx.get_lock():
            target_idx.value += 1
    # release this process's connection to the plasma store
    batch_queue.disconnect_client()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台