#! /usr/bin/python
# -*- coding: utf-8 -*-
"""Thin functional wrappers around the ``tf.data`` API.

Each wrapper exposes one ``tf.data.Dataset`` operation as a free function so
that dataset pipelines can be built with a uniform, framework-agnostic calling
convention.  Several wrappers accept extra keyword arguments (``num_shards``,
``shard_id``, ``sampler``, ``cache``, ...) that mirror other backends'
signatures; they are accepted for API compatibility but ignored by the
TensorFlow implementation.
"""

import tensorflow as tf

__all__ = [
    'Apply',
    'Batch',
    'Concat',
    'CsvDataset',
    'Filter',
    'Flat_map',
    'FromGenerator',
    'FromSlices',
    'Map',
    'Prefetch',
    'Repeat',
    'Shuffle',
    'Skip',
    'Take',
    'TextFlieDataset',
    'TFRecordDataset',
    'Zip',
    'Dataloader',
]


def Apply(dataset, transformation_func):
    """Apply a transformation function to a dataset.

    ``Apply`` enables chaining of custom ``Dataset`` transformations, which are
    represented as functions that take one ``Dataset`` argument and return a
    transformed ``Dataset``.

    >>> dataset = tf.data.Dataset.range(100)
    >>> def dataset_fn(dataset):
    ...     return dataset.filter(lambda x: x < 5)
    >>> dataset = dataset.apply(dataset_fn)
    >>> list(dataset.as_numpy_iterator())
    [0, 1, 2, 3, 4]

    Parameters
    ----------
    dataset : tf.data.Dataset
        The input dataset.
    transformation_func : callable
        A function that takes one ``Dataset`` argument and returns a
        ``Dataset``.

    Returns
    -------
    tf.data.Dataset
        The dataset returned by applying ``transformation_func``.
    """
    return dataset.apply(transformation_func)


def Batch(dataset, batch_size, drop_remainder=False):
    """Combine consecutive elements of a dataset into batches.

    Parameters
    ----------
    dataset : tf.data.Dataset
        The input dataset.
    batch_size : int
        Number of consecutive elements to combine into a single batch.
    drop_remainder : bool
        Whether to drop the final batch when it has fewer than ``batch_size``
        elements.

    Returns
    -------
    tf.data.Dataset
        The batched dataset.
    """
    return dataset.batch(batch_size=batch_size, drop_remainder=drop_remainder)


def Concat(dataset_1, dataset_2):
    """Concatenate ``dataset_2`` after ``dataset_1``.

    Parameters
    ----------
    dataset_1 : tf.data.Dataset
        Dataset whose elements come first.
    dataset_2 : tf.data.Dataset
        Dataset whose elements are appended.

    Returns
    -------
    tf.data.Dataset
        The concatenated dataset.
    """
    return dataset_1.concatenate(dataset_2)


def CsvDataset(
    file_pattern, batch_size=1, column_names=None, column_defaults=None, label_name=None, select_columns=None,
    field_delim=',', use_quote_delim=True, na_value='', header=True, num_epochs=None, shuffle=True,
    shuffle_buffer_size=10000, shuffle_seed=None, prefetch_buffer_size=None, num_parallel_reads=None, sloppy=False,
    num_rows_for_inference=100, compression_type=None, ignore_errors=False, numples_samples=None, num_shards=None,
    shard_id=None, cache=None
):
    """Read CSV files into a dataset.

    Each element of the resulting dataset is a ``(features, labels)`` tuple
    that corresponds to a batch of CSV rows.  The features dictionary maps
    feature column names to ``Tensor``s containing the corresponding feature
    data, and labels is a ``Tensor`` containing the batch's label data.

    Parameters
    ----------
    file_pattern : str or list of str
        File path(s) or glob pattern(s) of the CSV files to read.
    batch_size : int
        Number of records to combine in a single batch.

    The remaining keyword arguments are forwarded verbatim to
    ``tf.data.experimental.make_csv_dataset``; see its documentation for
    details.  ``numples_samples`` (sic), ``num_shards``, ``shard_id`` and
    ``cache`` exist only for cross-backend signature compatibility and are
    ignored here.

    Returns
    -------
    tf.data.Dataset
        The dataset of ``(features, labels)`` batches.
    """
    # BUGFIX: previously every keyword argument was shadowed by a hard-coded
    # default literal, so caller-supplied options were silently ignored.
    return tf.data.experimental.make_csv_dataset(
        file_pattern,
        batch_size,
        column_names=column_names,
        column_defaults=column_defaults,
        label_name=label_name,
        select_columns=select_columns,
        field_delim=field_delim,
        use_quote_delim=use_quote_delim,
        na_value=na_value,
        header=header,
        num_epochs=num_epochs,
        shuffle=shuffle,
        shuffle_buffer_size=shuffle_buffer_size,
        shuffle_seed=shuffle_seed,
        prefetch_buffer_size=prefetch_buffer_size,
        num_parallel_reads=num_parallel_reads,
        sloppy=sloppy,
        num_rows_for_inference=num_rows_for_inference,
        compression_type=compression_type,
        ignore_errors=ignore_errors,
    )


def Filter(dataset, predicate):
    """Filter a dataset according to ``predicate``.

    Parameters
    ----------
    dataset : tf.data.Dataset
        The input dataset.
    predicate : callable
        A function mapping a dataset element to a boolean.

    Returns
    -------
    tf.data.Dataset
        The dataset containing the elements for which ``predicate`` is True.
    """
    return dataset.filter(predicate)


def Flat_map(dataset, map_func):
    """Map ``map_func`` across a dataset and flatten the result.

    Parameters
    ----------
    dataset : tf.data.Dataset
        The input dataset.
    map_func : callable
        A function mapping a dataset element to a dataset.

    Returns
    -------
    tf.data.Dataset
        The flattened dataset.
    """
    return dataset.flat_map(map_func)


def FromGenerator(
    generator, output_types, output_shapes=None, args=None, column_names=None, column_types=None, schema=None,
    num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None,
    python_multiprocessing=True
):
    """Create a ``Dataset`` whose elements are generated by ``generator``.

    Parameters
    ----------
    generator : callable
        A callable object that returns an object supporting the ``iter()``
        protocol.
    output_types
        Structure of ``tf.DType`` objects corresponding to each element
        yielded by ``generator``.
    output_shapes : optional
        Structure of ``tf.TensorShape`` objects corresponding to each element.
    args : optional
        Tuple of ``tf.Tensor`` objects evaluated and passed to ``generator``.

    The remaining keyword arguments exist only for cross-backend signature
    compatibility and are ignored by the TensorFlow implementation.

    Returns
    -------
    tf.data.Dataset
        The generated dataset.
    """
    return tf.data.Dataset.from_generator(generator, output_types, output_shapes=output_shapes, args=args)


def FromSlices(
    tensor, column_names=None, num_samples=None, num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None,
    shard_id=None
):
    """Create a ``Dataset`` whose elements are slices of ``tensor``.

    Parameters
    ----------
    tensor
        A (nested structure of) tensor(s); the dataset yields slices along
        the first dimension.

    The remaining keyword arguments exist only for cross-backend signature
    compatibility and are ignored by the TensorFlow implementation.

    Returns
    -------
    tf.data.Dataset
        The sliced dataset.
    """
    return tf.data.Dataset.from_tensor_slices(tensor)


def Map(
    dataset, map_func, num_parallel_calls=None, input_columns=None, output_columns=None, column_order=None,
    num_parallel_workers=None, python_multiprocessing=False, cache=None, callbacks=None
):
    """Map ``map_func`` across the elements of a dataset.

    Parameters
    ----------
    dataset : tf.data.Dataset
        The input dataset.
    map_func : callable
        A function mapping a dataset element to another dataset element.
    num_parallel_calls : optional
        Number of elements to process in parallel (forwarded to
        ``Dataset.map``).

    The remaining keyword arguments exist only for cross-backend signature
    compatibility and are ignored by the TensorFlow implementation.

    Returns
    -------
    tf.data.Dataset
        The mapped dataset.
    """
    return dataset.map(map_func, num_parallel_calls=num_parallel_calls)


def Prefetch(dataset, buffer_size):
    """Create a dataset that prefetches elements from ``dataset``.

    Parameters
    ----------
    dataset : tf.data.Dataset
        The input dataset.
    buffer_size : int
        A ``tf.int64`` scalar, the maximum number of elements buffered when
        prefetching.

    Returns
    -------
    tf.data.Dataset
        The prefetching dataset.
    """
    return dataset.prefetch(buffer_size=buffer_size)


def Repeat(dataset, count=None):
    """Repeat a dataset ``count`` times (indefinitely when ``count`` is None).

    Parameters
    ----------
    dataset : tf.data.Dataset
        The input dataset.
    count : int, optional
        Number of repetitions; ``None`` repeats forever.

    Returns
    -------
    tf.data.Dataset
        The repeated dataset.
    """
    return dataset.repeat(count=count)


def Shuffle(dataset, buffer_size, seed=None, reshuffle_each_iteration=None):
    """Randomly shuffle the elements of a dataset.

    Parameters
    ----------
    dataset : tf.data.Dataset
        The input dataset.
    buffer_size : int
        Number of elements from which the new dataset will sample.
    seed : int, optional
        Random seed for the shuffle.
    reshuffle_each_iteration : bool, optional
        Whether the dataset is pseudo-randomly reshuffled each iteration.

    Returns
    -------
    tf.data.Dataset
        The shuffled dataset.
    """
    return dataset.shuffle(buffer_size, seed=seed, reshuffle_each_iteration=reshuffle_each_iteration)


def Skip(dataset, count):
    """Create a dataset that skips ``count`` elements of ``dataset``.

    Parameters
    ----------
    dataset : tf.data.Dataset
        The input dataset.
    count : int
        A ``tf.int64`` scalar, the number of elements to skip.  If ``count``
        is greater than the dataset size the result is empty; ``-1`` skips
        the entire dataset.

    Returns
    -------
    tf.data.Dataset
        The dataset with the first ``count`` elements removed.
    """
    return dataset.skip(count)


def Take(dataset, count):
    """Create a dataset with at most ``count`` elements of ``dataset``.

    Parameters
    ----------
    dataset : tf.data.Dataset
        The input dataset.
    count : int
        A ``tf.int64`` scalar, the number of elements to take.  If ``count``
        is ``-1`` or exceeds the dataset size, all elements are kept.

    Returns
    -------
    tf.data.Dataset
        The truncated dataset.
    """
    return dataset.take(count)


def TextFlieDataset(
    filenames, compression_type=None, buffer_size=None, num_parallel_reads=None, num_samples=None, shuffle=None,
    num_shards=None, shard_id=None, cache=None
):
    """Create a dataset comprising lines from one or more text files.

    NOTE(review): the name is a typo for ``TextFileDataset`` but is kept as-is
    because it is part of the exported API (``__all__``).

    Parameters
    ----------
    filenames : str or list of str
        Path(s) of the text file(s) to read.
    compression_type : str, optional
        One of ``""``, ``"ZLIB"`` or ``"GZIP"``.
    buffer_size : int, optional
        Number of bytes to buffer.
    num_parallel_reads : int, optional
        Number of files to read in parallel.

    The remaining keyword arguments exist only for cross-backend signature
    compatibility and are ignored by the TensorFlow implementation.

    Returns
    -------
    tf.data.TextLineDataset
        The line-oriented dataset.
    """
    return tf.data.TextLineDataset(filenames, compression_type, buffer_size, num_parallel_reads)


def TFRecordDataset(
    filenames, compression_type=None, buffer_size=None, num_parallel_reads=None, schema=None, columns_list=None,
    num_samples=None, shuffle=None, num_shards=None, shard_id=None, shard_equal_rows=False, cache=None
):
    """Create a dataset comprising records from one or more TFRecord files.

    Parameters
    ----------
    filenames : str or list of str
        Path(s) of the TFRecord file(s) to read.
    compression_type : str, optional
        One of ``""``, ``"ZLIB"`` or ``"GZIP"``.
    buffer_size : int, optional
        Number of bytes in the read buffer.
    num_parallel_reads : int, optional
        Number of files to read in parallel.

    The remaining keyword arguments exist only for cross-backend signature
    compatibility and are ignored by the TensorFlow implementation.

    Returns
    -------
    tf.data.TFRecordDataset
        The record-oriented dataset.
    """
    return tf.data.TFRecordDataset(filenames, compression_type, buffer_size, num_parallel_reads)


def Zip(datasets):
    """Create a dataset by zipping together the given datasets.

    Parameters
    ----------
    datasets : tuple of tf.data.Dataset
        The datasets to be zipped together.

    Returns
    -------
    tf.data.Dataset
        The zipped dataset.
    """
    return tf.data.Dataset.zip(datasets)


def Dataloader(dataset, batch_size, shuffle=False, drop_last=False, prefetch=0, shuffle_buffer_size=1024):
    """Build a standard input pipeline: optional shuffle, batch, prefetch.

    Parameters
    ----------
    dataset : tf.data.Dataset
        The input dataset.
    batch_size : int
        Number of elements per batch.
    shuffle : bool
        Whether to shuffle the dataset before batching.
    drop_last : bool
        Whether to drop the final incomplete batch.
    prefetch : int
        Prefetch buffer size (``0`` disables buffering).
    shuffle_buffer_size : int
        Buffer size used when ``shuffle`` is True.

    Returns
    -------
    tf.data.Dataset
        The assembled pipeline.
    """
    if shuffle:
        dataset = Shuffle(dataset, buffer_size=shuffle_buffer_size, reshuffle_each_iteration=True)

    dataset = Batch(dataset, batch_size=batch_size, drop_remainder=drop_last)
    dataset = Prefetch(dataset, buffer_size=prefetch)

    return dataset