You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

tensorflow_data.py 7.9 kB

4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import tensorflow as tf
# Public API of this module: thin, backend-agnostic wrappers around
# tf.data dataset construction and transformation operations.
# NOTE(review): 'TextFlieDataset' misspells 'File'; it is kept here because
# it matches the published name of the function defined below.
__all__ = [
    'Apply',
    'Batch',
    'Concat',
    'CsvDataset',
    'Filter',
    'Flat_map',
    'FromGenerator',
    'FromSlices',
    'Map',
    'Prefetch',
    'Repeat',
    'Shuffle',
    'Skip',
    'Take',
    'TextFlieDataset',
    'TFRecordDataset',
    'Zip',
    'Dataloader',
]
  24. def Apply(dataset, transformation_func):
  25. """Applies a transformation function to this dataset.
  26. `apply` enables chaining of custom `Dataset` transformations, which are
  27. represented as functions that take one `Dataset` argument and return a
  28. transformed `Dataset`.
  29. >>> dataset = tf.data.Dataset.range(100)
  30. >>> def dataset_fn(dataset):
  31. ... return dataset.filter(lambda x: x < 5)
  32. >>> dataset = dataset.apply(dataset_fn)
  33. >>> list(dataset.as_numpy_iterator())
  34. [0, 1, 2, 3, 4]
  35. Args:
  36. transformation_func: A function that takes one `Dataset` argument and
  37. returns a `Dataset`.
  38. Returns:
  39. Dataset: The `Dataset` returned by applying `transformation_func` to this
  40. dataset.
  41. """
  42. return dataset.apply(transformation_func)
  43. def Batch(dataset, batch_size, drop_remainder=False):
  44. '''
  45. Parameters
  46. ----------
  47. dataset
  48. batch_size
  49. drop_remainder
  50. Returns
  51. -------
  52. '''
  53. return dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
  54. def Concat(dataset_1, dataset_2):
  55. return dataset_1.concatenate(dataset_2)
  56. def CsvDataset(
  57. file_pattern, batch_size=1, column_names=None, column_defaults=None, label_name=None, select_columns=None,
  58. field_delim=',', use_quote_delim=True, na_value='', header=True, num_epochs=None, shuffle=True,
  59. shuffle_buffer_size=10000, shuffle_seed=None, prefetch_buffer_size=None, num_parallel_reads=None, sloppy=False,
  60. num_rows_for_inference=100, compression_type=None, ignore_errors=False, numples_samples=None, num_shards=None,
  61. shard_id=None, cache=None
  62. ):
  63. """Reads CSV files into a dataset.
  64. Reads CSV files into a dataset, where each element is a (features, labels)
  65. tuple that corresponds to a batch of CSV rows. The features dictionary
  66. maps feature column names to `Tensor`s containing the corresponding
  67. feature data, and labels is a `Tensor` containing the batch's label data.
  68. """
  69. return tf.data.experimental.make_csv_dataset(
  70. file_pattern, batch_size, column_names=None, column_defaults=None, label_name=None, select_columns=None,
  71. field_delim=',', use_quote_delim=True, na_value='', header=True, num_epochs=None, shuffle=True,
  72. shuffle_buffer_size=10000, shuffle_seed=None, prefetch_buffer_size=None, num_parallel_reads=None, sloppy=False,
  73. num_rows_for_inference=100, compression_type=None, ignore_errors=False
  74. )
  75. def Filter(dataset, predicate):
  76. '''
  77. Filters this dataset according to predicate.
  78. Parameters
  79. ----------
  80. dataset :
  81. A dataset
  82. predicate :
  83. A function mapping a dataset element to a boolean.
  84. Returns :
  85. The Dataset containing the elements of this dataset for which predicate is True.
  86. -------
  87. '''
  88. return dataset.filter(predicate)
  89. def Flat_map(dataset, map_func):
  90. '''
  91. Maps map_func across this dataset and flattens the result.
  92. Parameters
  93. ----------
  94. dataset:
  95. A dataset
  96. map_func
  97. A function mapping a dataset element to a dataset.
  98. Returns
  99. A Dataset.
  100. -------
  101. '''
  102. return dataset.flat_map(map_func)
  103. def FromGenerator(
  104. generator, output_types, output_shapes=None, args=None, column_names=None, column_types=None, schema=None,
  105. num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
  106. python_multiprocessing=True
  107. ):
  108. """Creates a `Dataset` whose elements are generated by `generator`.
  109. generator:
  110. A callable object
  111. """
  112. return tf.data.Dataset.from_generator(generator, output_types, output_shapes=output_shapes, args=args)
  113. def FromSlices(
  114. tensor, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None,
  115. shard_id=None
  116. ):
  117. return tf.data.Dataset.from_tensor_slices(tensor)
  118. def Map(
  119. dataset, map_func, num_parallel_calls=None, input_columns=None, output_columns=None, column_order=None,
  120. num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None
  121. ):
  122. """ Maps map_func across the elements of this dataset.
  123. Parameters
  124. ----------
  125. dataset : DataFlow
  126. input DataFlow
  127. map_func : function
  128. A function mapping a dataset element to another dataset element.
  129. num_parallel_calls
  130. Returns
  131. -------
  132. """
  133. return dataset.map(map_func, num_parallel_calls=num_parallel_calls)
  134. def Prefetch(dataset, buffer_size):
  135. '''
  136. Creates a Dataset that prefetches elements from this dataset.
  137. Parameters
  138. ----------
  139. dataset: Dataflow
  140. A dataset
  141. buffer_size :
  142. A tf.int64 scalar tf.Tensor, representing the maximum number of elements that will be buffered when prefetching.
  143. Returns
  144. A Dataset
  145. -------
  146. '''
  147. return dataset.prefetch(buffer_size=buffer_size)
  148. def Repeat(dataset, count=None):
  149. return dataset.repeat(count=count)
  150. def Shuffle(dataset, buffer_size, seed=None, reshuffle_each_iteration=None):
  151. return dataset.shuffle(buffer_size, seed=seed, reshuffle_each_iteration=reshuffle_each_iteration)
  152. def Skip(dataset, count):
  153. '''
  154. Creates a Dataset that skips count elements from this dataset.
  155. Parameters
  156. ----------
  157. dataset:
  158. A dataset
  159. count:
  160. A tf.int64 scalar tf.Tensor, representing the number of elements of this dataset that should be skipped to form the new dataset.
  161. If count is greater than the size of this dataset, the new dataset will contain no elements.
  162. If count is -1, skips the entire dataset.
  163. Returns
  164. -------
  165. '''
  166. return dataset.skip(count)
  167. def Take(dataset, count):
  168. '''
  169. Creates a Dataset with at most count elements from this dataset.
  170. Parameters
  171. ----------
  172. dataset:
  173. A dataset
  174. count:
  175. A tf.int64 scalar tf.Tensor, representing the number of elements of this dataset that should be taken to form the new dataset.
  176. If count is -1, or if count is greater than the size of this dataset, the new dataset will contain all elements of this dataset.
  177. Returns
  178. -------
  179. '''
  180. return dataset.take(count)
  181. def TextFlieDataset(
  182. filenames, compression_type=None, buffer_size=None, num_parallel_reads=None, num_samples=None, shuffle=None,
  183. num_shards=None, shard_id=None, cache=None
  184. ):
  185. return tf.data.TextLineDataset(filenames, compression_type, buffer_size, num_parallel_reads)
  186. def TFRecordDataset(
  187. filenames, compression_type=None, buffer_size=None, num_parallel_reads=None, schema=None, columns_list=None,
  188. num_samples=None, shuffle=None, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None
  189. ):
  190. return tf.data.TFRecordDataset(filenames, compression_type, buffer_size, num_parallel_reads)
  191. def Zip(datasets):
  192. '''
  193. Creates a Dataset by zipping together the given datasets.
  194. Parameters
  195. ----------
  196. datasets:
  197. A tuple of datasets to be zipped together.
  198. Returns
  199. -------
  200. '''
  201. return tf.data.Dataset.zip(datasets)
  202. def Dataloader(dataset, batch_size, shuffle=False, drop_last=False, prefetch=0, shuffle_buffer_size=1024):
  203. if shuffle:
  204. dataset = Shuffle(dataset, buffer_size=shuffle_buffer_size, reshuffle_each_iteration=True)
  205. dataset = Batch(dataset, batch_size=batch_size, drop_remainder=drop_last)
  206. dataset = Prefetch(dataset, buffer_size=prefetch)
  207. return dataset

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库,计划兼容 TensorFlow、PyTorch、MindSpore、Paddle。