You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

tensorflow_data.py 7.5 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import tensorflow as tf
  4. __all__ = [
  5. 'Apply',
  6. 'Batch',
  7. 'Concat',
  8. 'CsvDataset',
  9. 'Filter',
  10. 'Flat_map',
  11. 'FromGenerator',
  12. 'FromSlices',
  13. 'Map',
  14. 'Prefetch',
  15. 'Repeat',
  16. 'Shuffle',
  17. 'Skip',
  18. 'Take',
  19. 'TextFlieDataset',
  20. 'TFRecordDataset',
  21. 'Zip',
  22. ]
  23. def Apply(dataset, transformation_func):
  24. """Applies a transformation function to this dataset.
  25. `apply` enables chaining of custom `Dataset` transformations, which are
  26. represented as functions that take one `Dataset` argument and return a
  27. transformed `Dataset`.
  28. >>> dataset = tf.data.Dataset.range(100)
  29. >>> def dataset_fn(dataset):
  30. ... return dataset.filter(lambda x: x < 5)
  31. >>> dataset = dataset.apply(dataset_fn)
  32. >>> list(dataset.as_numpy_iterator())
  33. [0, 1, 2, 3, 4]
  34. Args:
  35. transformation_func: A function that takes one `Dataset` argument and
  36. returns a `Dataset`.
  37. Returns:
  38. Dataset: The `Dataset` returned by applying `transformation_func` to this
  39. dataset.
  40. """
  41. return dataset.apply(transformation_func)
  42. def Batch(dataset, batch_size, drop_remainder=False):
  43. '''
  44. Parameters
  45. ----------
  46. dataset
  47. batch_size
  48. drop_remainder
  49. Returns
  50. -------
  51. '''
  52. return dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)
  53. def Concat(dataset_1, dataset_2):
  54. return dataset_1.concatenate(dataset_2)
  55. def CsvDataset(
  56. file_pattern, batch_size=1, column_names=None, column_defaults=None, label_name=None, select_columns=None,
  57. field_delim=',', use_quote_delim=True, na_value='', header=True, num_epochs=None, shuffle=True,
  58. shuffle_buffer_size=10000, shuffle_seed=None, prefetch_buffer_size=None, num_parallel_reads=None, sloppy=False,
  59. num_rows_for_inference=100, compression_type=None, ignore_errors=False, numples_samples=None, num_shards=None,
  60. shard_id=None, cache=None
  61. ):
  62. """Reads CSV files into a dataset.
  63. Reads CSV files into a dataset, where each element is a (features, labels)
  64. tuple that corresponds to a batch of CSV rows. The features dictionary
  65. maps feature column names to `Tensor`s containing the corresponding
  66. feature data, and labels is a `Tensor` containing the batch's label data.
  67. """
  68. return tf.data.experimental.make_csv_dataset(
  69. file_pattern, batch_size, column_names=None, column_defaults=None, label_name=None, select_columns=None,
  70. field_delim=',', use_quote_delim=True, na_value='', header=True, num_epochs=None, shuffle=True,
  71. shuffle_buffer_size=10000, shuffle_seed=None, prefetch_buffer_size=None, num_parallel_reads=None, sloppy=False,
  72. num_rows_for_inference=100, compression_type=None, ignore_errors=False
  73. )
  74. def Filter(dataset, predicate):
  75. '''
  76. Filters this dataset according to predicate.
  77. Parameters
  78. ----------
  79. dataset :
  80. A dataset
  81. predicate :
  82. A function mapping a dataset element to a boolean.
  83. Returns :
  84. The Dataset containing the elements of this dataset for which predicate is True.
  85. -------
  86. '''
  87. return dataset.filter(predicate)
  88. def Flat_map(dataset, map_func):
  89. '''
  90. Maps map_func across this dataset and flattens the result.
  91. Parameters
  92. ----------
  93. dataset:
  94. A dataset
  95. map_func
  96. A function mapping a dataset element to a dataset.
  97. Returns
  98. A Dataset.
  99. -------
  100. '''
  101. return dataset.flat_map(map_func)
  102. def FromGenerator(
  103. generator, output_types, output_shapes=None, args=None, column_names=None, column_types=None, schema=None,
  104. num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
  105. python_multiprocessing=True
  106. ):
  107. """Creates a `Dataset` whose elements are generated by `generator`.
  108. generator:
  109. A callable object
  110. """
  111. return tf.data.Dataset.from_generator(generator, output_types, output_shapes=output_shapes, args=args)
  112. def FromSlices(
  113. tensor, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None,
  114. shard_id=None
  115. ):
  116. return tf.data.Dataset.from_tensor_slices(tensor)
  117. def Map(
  118. dataset, map_func, num_parallel_calls=None, input_columns=None, output_columns=None, column_order=None,
  119. num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None
  120. ):
  121. """ Maps map_func across the elements of this dataset.
  122. Parameters
  123. ----------
  124. dataset : DataFlow
  125. input DataFlow
  126. map_func : function
  127. A function mapping a dataset element to another dataset element.
  128. num_parallel_calls
  129. Returns
  130. -------
  131. """
  132. return dataset.map(map_func, num_parallel_calls=num_parallel_calls)
  133. def Prefetch(dataset, buffer_size):
  134. '''
  135. Creates a Dataset that prefetches elements from this dataset.
  136. Parameters
  137. ----------
  138. dataset: Dataflow
  139. A dataset
  140. buffer_size :
  141. A tf.int64 scalar tf.Tensor, representing the maximum number of elements that will be buffered when prefetching.
  142. Returns
  143. A Dataset
  144. -------
  145. '''
  146. return dataset.prefetch(buffer_size=buffer_size)
  147. def Repeat(dataset, count=None):
  148. return dataset.repeat(count=count)
  149. def Shuffle(dataset, buffer_size, seed=None, reshuffle_each_iteration=None):
  150. return dataset.shuffle(buffer_size, seed=seed, reshuffle_each_iteration=reshuffle_each_iteration)
  151. def Skip(dataset, count):
  152. '''
  153. Creates a Dataset that skips count elements from this dataset.
  154. Parameters
  155. ----------
  156. dataset:
  157. A dataset
  158. count:
  159. A tf.int64 scalar tf.Tensor, representing the number of elements of this dataset that should be skipped to form the new dataset.
  160. If count is greater than the size of this dataset, the new dataset will contain no elements.
  161. If count is -1, skips the entire dataset.
  162. Returns
  163. -------
  164. '''
  165. return dataset.skip(count)
  166. def Take(dataset, count):
  167. '''
  168. Creates a Dataset with at most count elements from this dataset.
  169. Parameters
  170. ----------
  171. dataset:
  172. A dataset
  173. count:
  174. A tf.int64 scalar tf.Tensor, representing the number of elements of this dataset that should be taken to form the new dataset.
  175. If count is -1, or if count is greater than the size of this dataset, the new dataset will contain all elements of this dataset.
  176. Returns
  177. -------
  178. '''
  179. return dataset.take(count)
  180. def TextFlieDataset(
  181. filenames, compression_type=None, buffer_size=None, num_parallel_reads=None, num_samples=None, shuffle=None,
  182. num_shards=None, shard_id=None, cache=None
  183. ):
  184. return tf.data.TextLineDataset(filenames, compression_type, buffer_size, num_parallel_reads)
  185. def TFRecordDataset(
  186. filenames, compression_type=None, buffer_size=None, num_parallel_reads=None, schema=None, columns_list=None,
  187. num_samples=None, shuffle=None, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None
  188. ):
  189. return tf.data.TFRecordDataset(filenames, compression_type, buffer_size, num_parallel_reads)
  190. def Zip(datasets):
  191. '''
  192. Creates a Dataset by zipping together the given datasets.
  193. Parameters
  194. ----------
  195. datasets:
  196. A tuple of datasets to be zipped together.
  197. Returns
  198. -------
  199. '''
  200. return tf.data.Dataset.zip(datasets)

TensorLayer 3.0 是一款兼容多种深度学习框架为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore、Paddle。