| @@ -0,0 +1,34 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| { | |||||
| public abstract class DataAdapter | |||||
| { | |||||
| protected DataAdapterArgs args; | |||||
| protected IDatasetV2 dataset; | |||||
| public virtual bool CanHandle(Tensor x, Tensor y = null) | |||||
| => throw new NotImplementedException(); | |||||
| public virtual IDatasetV2 GetDataset() | |||||
| => dataset; | |||||
| public virtual int GetSize() | |||||
| => throw new NotImplementedException(""); | |||||
| public virtual (Tensor, Tensor) Expand1d(Tensor x, Tensor y) | |||||
| { | |||||
| if (y.TensorShape.ndim == 1) | |||||
| y = array_ops.expand_dims(y, axis: -1); | |||||
| return (x, y); | |||||
| } | |||||
| public virtual bool ShouldRecreateIterator() | |||||
| { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -91,12 +91,14 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| public IEnumerable<(int, OwnedIterator)> enumerate_epochs() | public IEnumerable<(int, OwnedIterator)> enumerate_epochs() | ||||
| { | { | ||||
| using var ownedIterator = new OwnedIterator(_dataset); | |||||
| var data_iterator = new OwnedIterator(_dataset); | |||||
| foreach (var epoch in range(_initial_epoch, _epochs)) | foreach (var epoch in range(_initial_epoch, _epochs)) | ||||
| { | { | ||||
| if (_insufficient_data) | if (_insufficient_data) | ||||
| break; | break; | ||||
| yield return (epoch, ownedIterator); | |||||
| if (_adapter.ShouldRecreateIterator()) | |||||
| data_iterator = new OwnedIterator(_dataset); | |||||
| yield return (epoch, data_iterator); | |||||
| } | } | ||||
| } | } | ||||
| @@ -5,31 +5,15 @@ using Tensorflow.Keras.ArgsDefinition; | |||||
| namespace Tensorflow.Keras.Engine.DataAdapters | namespace Tensorflow.Keras.Engine.DataAdapters | ||||
| { | { | ||||
| public class DatasetAdapter : IDataAdapter | |||||
| public class DatasetAdapter : DataAdapter, IDataAdapter | |||||
| { | { | ||||
| DataAdapterArgs args; | |||||
| IDatasetV2 _dataset => args.Dataset; | |||||
| public DatasetAdapter(DataAdapterArgs args) | public DatasetAdapter(DataAdapterArgs args) | ||||
| { | { | ||||
| this.args = args; | this.args = args; | ||||
| dataset = args.Dataset; | |||||
| } | } | ||||
| public bool CanHandle(Tensor x, Tensor y = null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public IDatasetV2 GetDataset() | |||||
| => _dataset; | |||||
| public int GetSize() | |||||
| public override int GetSize() | |||||
| => -1; | => -1; | ||||
| public (Tensor, Tensor) Expand1d(Tensor x, Tensor y) | |||||
| { | |||||
| if (y.TensorShape.ndim == 1) | |||||
| y = array_ops.expand_dims(y, axis: -1); | |||||
| return (x, y); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -17,5 +17,6 @@ | |||||
| IDatasetV2 GetDataset(); | IDatasetV2 GetDataset(); | ||||
| int GetSize(); | int GetSize(); | ||||
| (Tensor, Tensor) Expand1d(Tensor x, Tensor y); | (Tensor, Tensor) Expand1d(Tensor x, Tensor y); | ||||
| bool ShouldRecreateIterator(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -7,14 +7,12 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| /// <summary> | /// <summary> | ||||
| /// Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy. | /// Adapter that handles Tensor-like objects, e.g. EagerTensor and NumPy. | ||||
| /// </summary> | /// </summary> | ||||
| public class TensorLikeDataAdapter : IDataAdapter | |||||
| public class TensorLikeDataAdapter : DataAdapter, IDataAdapter | |||||
| { | { | ||||
| DataAdapterArgs args; | |||||
| int _size; | int _size; | ||||
| int _batch_size; | int _batch_size; | ||||
| int num_samples; | int num_samples; | ||||
| int num_full_batches; | int num_full_batches; | ||||
| IDatasetV2 _dataset; | |||||
| public TensorLikeDataAdapter(DataAdapterArgs args) | public TensorLikeDataAdapter(DataAdapterArgs args) | ||||
| { | { | ||||
| @@ -31,7 +29,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| indices_dataset = indices_dataset.repeat(); | indices_dataset = indices_dataset.repeat(); | ||||
| indices_dataset = indices_dataset.map(permutation).prefetch(1); | indices_dataset = indices_dataset.map(permutation).prefetch(1); | ||||
| indices_dataset = indices_dataset.flat_map(slice_batch_indices); | indices_dataset = indices_dataset.flat_map(slice_batch_indices); | ||||
| _dataset = slice_inputs(indices_dataset, args.X, args.Y); | |||||
| dataset = slice_inputs(indices_dataset, args.X, args.Y); | |||||
| } | } | ||||
| Tensor permutation(Tensor tensor) | Tensor permutation(Tensor tensor) | ||||
| @@ -73,26 +71,11 @@ namespace Tensorflow.Keras.Engine.DataAdapters | |||||
| return dataset; | return dataset; | ||||
| } | } | ||||
| public bool CanHandle(Tensor x, Tensor y = null) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| void _process_tensorlike() | |||||
| { | |||||
| } | |||||
| public IDatasetV2 GetDataset() | |||||
| => _dataset; | |||||
| public int GetSize() | |||||
| public override int GetSize() | |||||
| => _size; | => _size; | ||||
| public (Tensor, Tensor) Expand1d(Tensor x, Tensor y) | |||||
| void _process_tensorlike() | |||||
| { | { | ||||
| if (y.TensorShape.ndim == 1) | |||||
| y = array_ops.expand_dims(y, axis: -1); | |||||
| return (x, y); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -41,7 +41,7 @@ namespace Tensorflow.Keras.Utils | |||||
| } | } | ||||
| var wc = new WebClient(); | var wc = new WebClient(); | ||||
| Console.WriteLine($"Downloading {relativeFilePath}"); | |||||
| Console.WriteLine($"Downloading from {url}"); | |||||
| var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath)); | var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath)); | ||||
| while (!download.IsCompleted) | while (!download.IsCompleted) | ||||
| { | { | ||||
| @@ -49,7 +49,7 @@ namespace Tensorflow.Keras.Utils | |||||
| Console.Write("."); | Console.Write("."); | ||||
| } | } | ||||
| Console.WriteLine(""); | Console.WriteLine(""); | ||||
| Console.WriteLine($"Downloaded {relativeFilePath}"); | |||||
| Console.WriteLine($"Downloaded to {relativeFilePath}"); | |||||
| return true; | return true; | ||||
| } | } | ||||