From 3516079dbb262c3bfaf6e82e3d553e82528d824d Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 26 Jan 2020 18:17:55 -0600 Subject: [PATCH] Tensorflow Hub --- TensorFlow.NET.sln | 82 +++++++- src/TensorFlowNET.Core/Binding.Util.cs | 2 +- .../TensorFlow.Binding.csproj | 5 +- src/TensorFlowNET.Hub/DataSetBase.cs | 13 ++ src/TensorFlowNET.Hub/Datasets.cs | 46 +++++ src/TensorFlowNET.Hub/IDataSet.cs | 13 ++ src/TensorFlowNET.Hub/IModelLoader.cs | 14 ++ src/TensorFlowNET.Hub/MnistDataSet.cs | 88 +++++++++ src/TensorFlowNET.Hub/MnistModelLoader.cs | 184 ++++++++++++++++++ src/TensorFlowNET.Hub/ModelLoadSetting.cs | 20 ++ src/TensorFlowNET.Hub/README.md | 5 + src/TensorFlowNET.Hub/Tensorflow.Hub.csproj | 26 +++ src/TensorFlowNET.Hub/Utils.cs | 137 +++++++++++++ .../Hub/MnistModelLoaderTest.cs | 24 +++ test/TensorFlowNET.UnitTest/UnitTest.csproj | 1 + 15 files changed, 657 insertions(+), 3 deletions(-) create mode 100644 src/TensorFlowNET.Hub/DataSetBase.cs create mode 100644 src/TensorFlowNET.Hub/Datasets.cs create mode 100644 src/TensorFlowNET.Hub/IDataSet.cs create mode 100644 src/TensorFlowNET.Hub/IModelLoader.cs create mode 100644 src/TensorFlowNET.Hub/MnistDataSet.cs create mode 100644 src/TensorFlowNET.Hub/MnistModelLoader.cs create mode 100644 src/TensorFlowNET.Hub/ModelLoadSetting.cs create mode 100644 src/TensorFlowNET.Hub/README.md create mode 100644 src/TensorFlowNET.Hub/Tensorflow.Hub.csproj create mode 100644 src/TensorFlowNET.Hub/Utils.cs create mode 100644 test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 3caec27f..1330e37e 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -11,12 +11,20 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "UnitTest", "test\TensorFlow EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras", "src\TensorFlowNET.Keras\Tensorflow.Keras.csproj", "{6268B461-486A-460B-9B3C-86493CBBAAF7}" EndProject 
-Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorFlowNET.Hub\Tensorflow.Hub.csproj", "{95B077C1-E21B-486F-8BDD-1C902FE687AB}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU Debug|x64 = Debug|x64 + Debug-Minimal|Any CPU = Debug-Minimal|Any CPU + Debug-Minimal|x64 = Debug-Minimal|x64 + Publish|Any CPU = Publish|Any CPU + Publish|x64 = Publish|x64 Release|Any CPU = Release|Any CPU Release|x64 = Release|x64 EndGlobalSection @@ -25,6 +33,14 @@ Global {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.Build.0 = Debug|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.Build.0 = Debug|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.Build.0 = Release|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.ActiveCfg = Release|Any CPU + {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.Build.0 = 
Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|Any CPU @@ -33,6 +49,14 @@ Global {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.ActiveCfg = Debug|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.Build.0 = Debug|Any CPU + {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU + {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU + {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU + {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.Build.0 = Debug|Any CPU + {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.Build.0 = Release|Any CPU + {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.ActiveCfg = Release|Any CPU + {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.Build.0 = Release|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.ActiveCfg = Release|Any CPU @@ -41,6 +65,14 @@ Global {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.Build.0 = Debug|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU + 
{23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.Build.0 = Debug|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.Build.0 = Release|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.ActiveCfg = Release|Any CPU + {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.Build.0 = Release|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.ActiveCfg = Release|Any CPU @@ -49,6 +81,14 @@ Global {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.ActiveCfg = Debug|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.Build.0 = Debug|Any CPU + {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU + {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU + {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU + {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.Build.0 = Debug|Any CPU + {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.Build.0 = Release|Any CPU + {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.ActiveCfg = Release|Any CPU + {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.Build.0 = Release|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.ActiveCfg = Release|Any CPU @@ -57,10 +97,50 @@ Global {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.ActiveCfg = Debug|Any CPU 
{EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.Build.0 = Debug|Any CPU + {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU + {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU + {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU + {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.Build.0 = Debug|Any CPU + {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.Build.0 = Release|Any CPU + {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.ActiveCfg = Release|Any CPU + {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.Build.0 = Release|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|Any CPU {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|x64.ActiveCfg = Debug|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|x64.Build.0 = Debug|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|x64.Build.0 = Debug|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|Any CPU.ActiveCfg = Debug|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|Any CPU.Build.0 = Debug|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|x64.ActiveCfg = Debug|Any CPU + 
{95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|x64.Build.0 = Debug|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|Any CPU.Build.0 = Release|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|x64.ActiveCfg = Release|Any CPU + {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|x64.Build.0 = Release|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|x64.ActiveCfg = Debug|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|x64.Build.0 = Debug|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|x64.Build.0 = Debug|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|Any CPU.ActiveCfg = Release|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|Any CPU.Build.0 = Release|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|x64.ActiveCfg = Release|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|x64.Build.0 = Release|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|Any CPU.Build.0 = Release|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|x64.ActiveCfg = Release|Any CPU + {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|x64.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index eaeefd73..9f51ce2d 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ 
b/src/TensorFlowNET.Core/Binding.Util.cs @@ -88,7 +88,7 @@ namespace Tensorflow case ICollection arr: return arr.Count; case NDArray ndArray: - return ndArray.shape[0]; + return ndArray.ndim == 0 ? 1 : ndArray.shape[0]; case IEnumerable enumerable: return enumerable.OfType().Count(); } diff --git a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj index ea93f37d..633f179a 100644 --- a/src/TensorFlowNET.Core/TensorFlow.Binding.csproj +++ b/src/TensorFlowNET.Core/TensorFlow.Binding.csproj @@ -60,10 +60,13 @@ https://tensorflownet.readthedocs.io - + + + + diff --git a/src/TensorFlowNET.Hub/DataSetBase.cs b/src/TensorFlowNET.Hub/DataSetBase.cs new file mode 100644 index 00000000..dc47b1c8 --- /dev/null +++ b/src/TensorFlowNET.Hub/DataSetBase.cs @@ -0,0 +1,13 @@ +using System; +using System.Collections.Generic; +using System.Text; +using NumSharp; + +namespace Tensorflow.Hub +{ + public abstract class DataSetBase : IDataSet + { + public NDArray Data { get; protected set; } + public NDArray Labels { get; protected set; } + } +} diff --git a/src/TensorFlowNET.Hub/Datasets.cs b/src/TensorFlowNET.Hub/Datasets.cs new file mode 100644 index 00000000..6c05efb6 --- /dev/null +++ b/src/TensorFlowNET.Hub/Datasets.cs @@ -0,0 +1,46 @@ +using System; +using System.Collections.Generic; +using System.Text; +using NumSharp; + +namespace Tensorflow.Hub +{ + public class Datasets where TDataSet : IDataSet + { + public TDataSet Train { get; private set; } + + public TDataSet Validation { get; private set; } + + public TDataSet Test { get; private set; } + + public Datasets(TDataSet train, TDataSet validation, TDataSet test) + { + Train = train; + Validation = validation; + Test = test; + } + + public (NDArray, NDArray) Randomize(NDArray x, NDArray y) + { + var perm = np.random.permutation(y.shape[0]); + np.random.shuffle(perm); + return (x[perm], y[perm]); + } + + /// + /// selects a few number of images determined by the batch_size 
variable (if you don't know why, read about Stochastic Gradient Method)
+        /// The batch is the row range selected by new Slice(start, end), applied to both arrays.
+        /// x: the samples, sliced along the first axis.
+        /// y: the labels, sliced along the first axis with the same range.
+        /// start: first index passed to the slice.
+        /// end: stop index passed to the slice.
+        /// Returns the sliced (x, y) pair.
+        public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end)
+        {
+            var slice = new Slice(start, end);
+            var x_batch = x[slice];
+            var y_batch = y[slice];
+            return (x_batch, y_batch);
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Hub/IDataSet.cs b/src/TensorFlowNET.Hub/IDataSet.cs
new file mode 100644
index 00000000..f38a4217
--- /dev/null
+++ b/src/TensorFlowNET.Hub/IDataSet.cs
@@ -0,0 +1,19 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using NumSharp;
+
+namespace Tensorflow.Hub
+{
+    /// <summary>
+    /// A set of samples paired with their labels.
+    /// </summary>
+    public interface IDataSet
+    {
+        /// <summary>The samples.</summary>
+        NDArray Data { get; }
+
+        /// <summary>The labels that belong to <see cref="Data"/>.</summary>
+        NDArray Labels { get; }
+    }
+}
diff --git a/src/TensorFlowNET.Hub/IModelLoader.cs b/src/TensorFlowNET.Hub/IModelLoader.cs
new file mode 100644
index 00000000..530138af
--- /dev/null
+++ b/src/TensorFlowNET.Hub/IModelLoader.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Threading.Tasks;
+using System.Collections.Generic;
+using System.Text;
+using NumSharp;
+
+namespace Tensorflow.Hub
+{
+    /// <summary>
+    /// Loads the train, validation and test splits of a data set.
+    /// </summary>
+    /// <typeparam name="TDataSet">The type used for each split.</typeparam>
+    public interface IModelLoader<TDataSet>
+        where TDataSet : IDataSet
+    {
+        /// <summary>
+        /// Loads the data set described by <paramref name="setting"/>.
+        /// </summary>
+        /// <param name="setting">The load options.</param>
+        /// <returns>A task that produces the loaded splits.</returns>
+        Task<Datasets<TDataSet>> LoadAsync(ModelLoadSetting setting);
+    }
+}
diff --git a/src/TensorFlowNET.Hub/MnistDataSet.cs b/src/TensorFlowNET.Hub/MnistDataSet.cs
new file mode 100644
index 00000000..4cd9663b
--- /dev/null
+++ b/src/TensorFlowNET.Hub/MnistDataSet.cs
@@ -0,0 +1,88 @@
+using System;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.Text;
+using NumSharp;
+using Tensorflow;
+
+namespace Tensorflow.Hub
+{
+    public class MnistDataSet : DataSetBase
+    {
+        public int NumOfExamples { get; private set; }
+        public int EpochsCompleted { get; private set; }
+        public int IndexInEpoch { get; private set; }
+
+        public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape)
+        {
+            EpochsCompleted = 0;
+            IndexInEpoch = 0;
+
+            NumOfExamples = images.shape[0];
+
+            images = images.reshape(images.shape[0], 
images.shape[1] * images.shape[2]);
+            images = images.astype(dataType);
+            // Scale the values by 1/255. The Stopwatch/Console.WriteLine timing that used to
+            // wrap this call was debug residue; it has been removed so that constructing a
+            // data set no longer writes to standard output.
+            images = np.multiply(images, 1.0f / 255.0f);
+            Data = images;
+
+            labels = labels.astype(dataType);
+            Labels = labels;
+        }
+
+        /// <summary>
+        /// Returns the next <paramref name="batch_size"/> examples and advances
+        /// <see cref="IndexInEpoch"/>. When the end of the data is reached the epoch counter is
+        /// incremented and, if requested, the examples are reshuffled before reading continues
+        /// from the start.
+        /// </summary>
+        /// <param name="batch_size">Number of examples to return.</param>
+        /// <param name="fake_data">Not used by this implementation.</param>
+        /// <param name="shuffle">Whether to shuffle before the first batch and at every epoch boundary.</param>
+        /// <returns>The (data, labels) pair for the batch.</returns>
+        public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true)
+        {
+            if (IndexInEpoch >= NumOfExamples)
+            {
+                // The previous batch ended exactly on the last example. This is an epoch
+                // boundary too: count it and reshuffle, instead of silently rewinding
+                // (which left EpochsCompleted stuck whenever batch_size divides NumOfExamples).
+                EpochsCompleted += 1;
+                IndexInEpoch = 0;
+
+                if (shuffle)
+                    ShuffleExamples();
+            }
+
+            var start = IndexInEpoch;
+            // Shuffle for the first epoch
+            if (EpochsCompleted == 0 && start == 0 && shuffle)
+                ShuffleExamples();
+
+            // Go to the next epoch
+            if (start + batch_size > NumOfExamples)
+            {
+                // Finished epoch
+                EpochsCompleted += 1;
+
+                // Get the rest examples in this epoch
+                var rest_num_examples = NumOfExamples - start;
+                var images_rest_part = Data[np.arange(start, NumOfExamples)];
+                var labels_rest_part = Labels[np.arange(start, NumOfExamples)];
+                // Shuffle the data
+                if (shuffle)
+                    ShuffleExamples();
+
+                start = 0;
+                IndexInEpoch = batch_size - rest_num_examples;
+                var end = IndexInEpoch;
+                var images_new_part = Data[np.arange(start, end)];
+                var labels_new_part = Labels[np.arange(start, end)];
+
+                return (np.concatenate(new[] { images_rest_part, images_new_part }, axis: 0),
+                    np.concatenate(new[] { labels_rest_part, labels_new_part }, axis: 0));
+            }
+            else
+            {
+                IndexInEpoch += batch_size;
+                var end = IndexInEpoch;
+                return (Data[np.arange(start, end)], Labels[np.arange(start, end)]);
+            }
+        }
+
+        /// <summary>
+        /// Reorders Data and Labels with one shared random permutation so that every example
+        /// keeps its label.
+        /// </summary>
+        private void ShuffleExamples()
+        {
+            var perm = np.arange(NumOfExamples);
+            np.random.shuffle(perm);
+            Data = Data[perm];
+            Labels = Labels[perm];
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Hub/MnistModelLoader.cs b/src/TensorFlowNET.Hub/MnistModelLoader.cs
new file mode 100644
index 00000000..4fdd69b6
--- /dev/null
+++ b/src/TensorFlowNET.Hub/MnistModelLoader.cs
@@ -0,0 +1,184 @@
+using System;
+using System.Threading.Tasks;
+using System.Collections.Generic;
+using System.Text;
+using System.IO;
+using NumSharp;
+
+namespace Tensorflow.Hub
+{
+    /// <summary>
+    /// Downloads, unpacks and parses the MNIST image and label files.
+    /// </summary>
+    public class MnistModelLoader : IModelLoader<MnistDataSet>
+    {
+        private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
+        private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
+        private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
+        private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
+        private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
+
+        /// <summary>
+        /// Convenience entry point: builds a <see cref="ModelLoadSetting"/> from the arguments
+        /// and loads the data with a new <see cref="MnistModelLoader"/>.
+        /// </summary>
+        /// <param name="trainDir">Directory the files are downloaded to and read from.</param>
+        /// <param name="oneHot">Whether labels are returned one-hot encoded.</param>
+        /// <param name="trainSize">Optional cap on the number of training examples read.</param>
+        /// <param name="validationSize">Size of the validation split; the setting's default is kept when null.</param>
+        /// <param name="testSize">Optional cap on the number of test examples read.</param>
+        /// <param name="showProgressInConsole">Whether progress is written to the console.</param>
+        /// <returns>A task that produces the train, validation and test splits.</returns>
+        public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false)
+        {
+            var loader = new MnistModelLoader();
+
+            var setting = new ModelLoadSetting
+            {
+                TrainDir = trainDir,
+                OneHot = oneHot,
+                TrainSize = trainSize,
+                TestSize = testSize,
+                ShowProgressInConsole = showProgressInConsole
+            };
+
+            // ValidationSize is not nullable and carries its own default, so only override it
+            // when the caller supplied a value.
+            if (validationSize.HasValue)
+                setting.ValidationSize = validationSize.Value;
+
+            return await loader.LoadAsync(setting).ConfigureAwait(false);
+        }
+
+        /// <summary>
+        /// Downloads and unzips the four MNIST files into <c>setting.TrainDir</c>, parses them,
+        /// splits the first <c>setting.ValidationSize</c> training examples off as the
+        /// validation set and wraps the three splits in <see cref="MnistDataSet"/> instances.
+        /// </summary>
+        /// <param name="setting">The load options.</param>
+        /// <returns>A task that produces the train, validation and test splits.</returns>
+        /// <exception cref="ArgumentNullException"><paramref name="setting"/> is null.</exception>
+        /// <exception cref="ArgumentException">The validation size is not smaller than the requested training size.</exception>
+        public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting)
+        {
+            if (setting == null)
+                throw new ArgumentNullException(nameof(setting));
+
+            if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value)
+                throw new ArgumentException("Validation set should be smaller than training set");
+
+            var sourceUrl = setting.SourceUrl;
+
+            if (string.IsNullOrEmpty(sourceUrl))
+                sourceUrl = DEFAULT_SOURCE_URL;
+            else if (!sourceUrl.EndsWith("/", StringComparison.Ordinal))
+                sourceUrl += "/"; // file names are appended directly, so the base must end in '/'
+
+            // load train images
+            await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole)
+                .ShowProgressInConsole(setting.ShowProgressInConsole);
+
+            await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole)
+                .ShowProgressInConsole(setting.ShowProgressInConsole);
+
+            var trainImages = 
ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize); + + // load train labels + await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize); + + // load test images + await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize); + + // load test labels + await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) + .ShowProgressInConsole(setting.ShowProgressInConsole); + + var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize); + + var end = trainImages.shape[0]; + + var validationSize = setting.ValidationSize; + + var validationImages = 
trainImages[np.arange(validationSize)];
+            var validationLabels = trainLabels[np.arange(validationSize)];
+
+            trainImages = trainImages[np.arange(validationSize, end)];
+            trainLabels = trainLabels[np.arange(validationSize, end)];
+
+            var dtype = setting.DataType;
+            var reshape = setting.ReShape;
+
+            var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape);
+            var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape);
+            var test = new MnistDataSet(testImages, testLabels, dtype, reshape);
+
+            return new Datasets<MnistDataSet>(train, validation, test);
+        }
+
+        /// <summary>
+        /// Parses an unzipped MNIST image file: a header of four 32-bit values (magic number
+        /// 2051, image count, rows, columns) followed by one byte per pixel.
+        /// </summary>
+        /// <param name="file">Path of the file; relative paths are resolved against <see cref="AppContext.BaseDirectory"/>.</param>
+        /// <param name="limit">Optional cap on the number of images read.</param>
+        /// <returns>A uint8 array of shape (images, rows, cols, 1).</returns>
+        /// <exception cref="Exception">The magic number is not 2051.</exception>
+        /// <exception cref="EndOfStreamException">The file ends before all pixel data was read.</exception>
+        private NDArray ExtractImages(string file, int? limit = null)
+        {
+            if (!Path.IsPathRooted(file))
+                file = Path.Combine(AppContext.BaseDirectory, file);
+
+            // File.OpenRead asks for read access only; new FileStream(file, FileMode.Open)
+            // requests read/write and therefore fails on read-only files.
+            using (var bytestream = File.OpenRead(file))
+            {
+                var magic = Read32(bytestream);
+                if (magic != 2051)
+                    throw new Exception($"Invalid magic number {magic} in MNIST image file: {file}");
+
+                var num_images = Read32(bytestream);
+                num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit);
+
+                var rows = Read32(bytestream);
+                var cols = Read32(bytestream);
+
+                var buf = new byte[rows * cols * num_images];
+
+                // Stream.Read may return fewer bytes than requested, so keep reading until the
+                // buffer is full and fail loudly on a truncated file.
+                var filled = 0;
+                while (filled < buf.Length)
+                {
+                    var read = bytestream.Read(buf, filled, buf.Length - filled);
+                    if (read == 0)
+                        throw new EndOfStreamException($"Unexpected end of MNIST image file: {file}");
+
+                    filled += read;
+                }
+
+                var data = np.frombuffer(buf, np.uint8);
+                data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
+
+                return data;
+            }
+        }
+
+        /// <summary>
+        /// Parses an unzipped MNIST label file: a header of two 32-bit values (magic number
+        /// 2049, item count) followed by one byte per label.
+        /// </summary>
+        /// <param name="file">Path of the file; relative paths are resolved against <see cref="AppContext.BaseDirectory"/>.</param>
+        /// <param name="one_hot">Whether to return the labels one-hot encoded.</param>
+        /// <param name="num_classes">Number of classes used for the one-hot encoding.</param>
+        /// <param name="limit">Optional cap on the number of labels read.</param>
+        /// <returns>The labels, one-hot encoded when <paramref name="one_hot"/> is true.</returns>
+        /// <exception cref="Exception">The magic number is not 2049.</exception>
+        private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
+        {
+            if (!Path.IsPathRooted(file))
+                file = Path.Combine(AppContext.BaseDirectory, file);
+
+            // Read-only access is sufficient for parsing (see ExtractImages).
+            using (var bytestream = File.OpenRead(file))
+            {
+                var magic = Read32(bytestream);
+                if (magic != 2049)
+                    throw new Exception($"Invalid magic number {magic} in MNIST label file: {file}");
+
+                var num_items = Read32(bytestream);
+                num_items = limit == null ? 
num_items : Math.Min(num_items, (uint)limit);
+
+                var buf = new byte[num_items];
+
+                // A single Stream.Read call may return fewer bytes than requested.
+                FillBuffer(bytestream, buf);
+
+                var labels = np.frombuffer(buf, np.uint8);
+
+                if (one_hot)
+                    return DenseToOneHot(labels, num_classes);
+
+                return labels;
+            }
+        }
+
+        /// <summary>
+        /// Converts a vector of class indices into a (labels, classes) matrix that holds 1.0 in
+        /// the column named by each label and 0 elsewhere.
+        /// </summary>
+        /// <param name="labels_dense">One byte-valued class index per example.</param>
+        /// <param name="num_classes">Number of columns of the result.</param>
+        /// <returns>The one-hot encoded labels.</returns>
+        private NDArray DenseToOneHot(NDArray labels_dense, int num_classes)
+        {
+            var num_labels = labels_dense.shape[0];
+            var labels_one_hot = np.zeros(num_labels, num_classes);
+            var labels = labels_dense.Data<byte>();
+            for (int row = 0; row < num_labels; row++)
+            {
+                var col = labels[row];
+                labels_one_hot.SetData(1.0, row, col);
+            }
+
+            return labels_one_hot;
+        }
+
+        /// <summary>
+        /// Reads the next four bytes of the stream as a big-endian unsigned 32-bit integer
+        /// (the byte order the former ">u4" dtype string asked for).
+        /// </summary>
+        /// <param name="bytestream">The stream to read from.</param>
+        /// <returns>The decoded value.</returns>
+        /// <exception cref="EndOfStreamException">Fewer than four bytes are left.</exception>
+        private uint Read32(FileStream bytestream)
+        {
+            var buffer = new byte[sizeof(uint)];
+            FillBuffer(bytestream, buffer);
+            return ((uint)buffer[0] << 24) | ((uint)buffer[1] << 16) | ((uint)buffer[2] << 8) | buffer[3];
+        }
+
+        /// <summary>
+        /// Reads from the stream until the buffer is full.
+        /// </summary>
+        /// <exception cref="EndOfStreamException">The stream ends before the buffer is full.</exception>
+        private static void FillBuffer(Stream stream, byte[] buffer)
+        {
+            var filled = 0;
+            while (filled < buffer.Length)
+            {
+                var read = stream.Read(buffer, filled, buffer.Length - filled);
+                if (read == 0)
+                    throw new EndOfStreamException("Unexpected end of MNIST data file.");
+
+                filled += read;
+            }
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Hub/ModelLoadSetting.cs b/src/TensorFlowNET.Hub/ModelLoadSetting.cs
new file mode 100644
index 00000000..89e46748
--- /dev/null
+++ b/src/TensorFlowNET.Hub/ModelLoadSetting.cs
@@ -0,0 +1,40 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using NumSharp;
+
+namespace Tensorflow.Hub
+{
+    /// <summary>
+    /// Options consumed by a model loader's <c>LoadAsync</c> method.
+    /// </summary>
+    public class ModelLoadSetting
+    {
+        /// <summary>Directory the data files are downloaded to and read from.</summary>
+        public string TrainDir { get; set; }
+
+        /// <summary>Whether labels are returned one-hot encoded.</summary>
+        public bool OneHot { get; set; }
+
+        /// <summary>Element type the data is converted to; defaults to <see cref="float"/>.</summary>
+        public Type DataType { get; set; } = typeof(float);
+
+        /// <summary>Passed through to the data set constructor.</summary>
+        public bool ReShape { get; set; }
+
+        /// <summary>Number of leading training examples split off as the validation set; defaults to 5000.</summary>
+        public int ValidationSize { get; set; } = 5000;
+
+        /// <summary>Optional cap on the number of training examples read.</summary>
+        public int? TrainSize { get; set; }
+
+        /// <summary>Optional cap on the number of test examples read.</summary>
+        public int? TestSize { get; set; }
+
+        /// <summary>Base URL of the data files; the loader uses its own default when null or empty.</summary>
+        public string SourceUrl { get; set; }
+
+        /// <summary>Whether progress messages are written to the console.</summary>
+        public bool ShowProgressInConsole { get; set; }
+    }
+}
diff --git a/src/TensorFlowNET.Hub/README.md b/src/TensorFlowNET.Hub/README.md
new file mode 100644
index 00000000..156b263d
--- /dev/null
+++ b/src/TensorFlowNET.Hub/README.md
@@ -0,0 +1,5 @@
+## TensorFlow Hub
+
+TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models. 
In particular, it provides **modules**, which are pre-trained pieces of TensorFlow models that can be reused on new tasks. + +https://github.com/tensorflow/hub \ No newline at end of file diff --git a/src/TensorFlowNET.Hub/Tensorflow.Hub.csproj b/src/TensorFlowNET.Hub/Tensorflow.Hub.csproj new file mode 100644 index 00000000..7f5191a6 --- /dev/null +++ b/src/TensorFlowNET.Hub/Tensorflow.Hub.csproj @@ -0,0 +1,26 @@ + + + Tensorflow.Hub + netstandard2.0 + 0.0.5 + Kerry Jiang, Haiping Chen + SciSharp STACK + Apache 2.0 + https://github.com/SciSharp/TensorFlow.NET + git + http://scisharpstack.org + TensorFlow, SciSharp, MachineLearning + TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models. + SciSharp.TensorFlowHub + true + Fix GetNextBatch() bug. + https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + TensorFlow.Hub + + + DEBUG;TRACE + + + + + \ No newline at end of file diff --git a/src/TensorFlowNET.Hub/Utils.cs b/src/TensorFlowNET.Hub/Utils.cs new file mode 100644 index 00000000..5b06aaad --- /dev/null +++ b/src/TensorFlowNET.Hub/Utils.cs @@ -0,0 +1,137 @@ +using System; +using System.IO; +using System.IO.Compression; +using System.Collections.Generic; +using System.Net; +using System.Text; +using System.Threading; +using System.Threading.Tasks; + +namespace Tensorflow.Hub +{ + public static class Utils + { + public static async Task DownloadAsync(this IModelLoader modelLoader, string url, string saveTo) + where TDataSet : IDataSet + { + var dir = Path.GetDirectoryName(saveTo); + var fileName = Path.GetFileName(saveTo); + await modelLoader.DownloadAsync(url, dir, fileName); + } + + public static async Task DownloadAsync(this IModelLoader modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false) + where TDataSet : IDataSet + { + if (!Path.IsPathRooted(dirSaveTo)) + dirSaveTo = Path.Combine(AppContext.BaseDirectory, dirSaveTo); + + var fileSaveTo = 
Path.Combine(dirSaveTo, fileName);
+
+            if (showProgressInConsole)
+            {
+                Console.WriteLine($"Downloading {fileName}");
+            }
+
+            if (File.Exists(fileSaveTo))
+            {
+                if (showProgressInConsole)
+                {
+                    Console.WriteLine($"The file {fileName} already exists");
+                }
+
+                return;
+            }
+
+            Directory.CreateDirectory(dirSaveTo);
+
+            // Download under a temporary name and rename on success. Otherwise an interrupted
+            // download leaves a truncated file that the exists-check above would accept as
+            // complete on the next run.
+            var tempFile = fileSaveTo + ".download";
+
+            try
+            {
+                using (var wc = new WebClient())
+                {
+                    await wc.DownloadFileTaskAsync(url, tempFile).ConfigureAwait(false);
+                }
+
+                File.Move(tempFile, fileSaveTo);
+            }
+            finally
+            {
+                // No-op after a successful move; removes the partial file after a failure.
+                if (File.Exists(tempFile))
+                    File.Delete(tempFile);
+            }
+        }
+
+        /// <summary>
+        /// Decompresses a gzip file into <paramref name="saveTo"/>. The output file is named
+        /// after the archive without its last extension. Nothing is done when that file
+        /// already exists.
+        /// </summary>
+        /// <typeparam name="TDataSet">Data set type of the loader this extension is called on.</typeparam>
+        /// <param name="modelLoader">The loader; only used to attach the extension method.</param>
+        /// <param name="zipFile">Path of the gzip file; relative paths are resolved against <see cref="AppContext.BaseDirectory"/>.</param>
+        /// <param name="saveTo">Target directory; created when missing, relative paths are resolved the same way.</param>
+        /// <param name="showProgressInConsole">Whether status messages are written to the console.</param>
+        /// <returns>A task that completes when the file has been written.</returns>
+        public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false)
+            where TDataSet : IDataSet
+        {
+            if (!Path.IsPathRooted(saveTo))
+                saveTo = Path.Combine(AppContext.BaseDirectory, saveTo);
+
+            Directory.CreateDirectory(saveTo);
+
+            if (!Path.IsPathRooted(zipFile))
+                zipFile = Path.Combine(AppContext.BaseDirectory, zipFile);
+
+            var destFileName = Path.GetFileNameWithoutExtension(zipFile);
+            var destFilePath = Path.Combine(saveTo, destFileName);
+
+            if (showProgressInConsole)
+                Console.WriteLine($"Unzipping {Path.GetFileName(zipFile)}");
+
+            if (File.Exists(destFilePath))
+            {
+                if (showProgressInConsole)
+                    Console.WriteLine($"The file {destFileName} already exists");
+
+                // Previously execution fell through and decompressed the archive again.
+                return;
+            }
+
+            // Decompress under a temporary name and rename on success, so that a failure
+            // midway cannot leave a truncated file that the check above would accept later.
+            var tempFile = destFilePath + ".unzip";
+
+            try
+            {
+                using (var unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress))
+                using (var destStream = File.Create(tempFile))
+                {
+                    await unzipStream.CopyToAsync(destStream).ConfigureAwait(false);
+                    await destStream.FlushAsync().ConfigureAwait(false);
+                }
+
+                File.Move(tempFile, destFilePath);
+            }
+            finally
+            {
+                // No-op after a successful move; removes the partial file after a failure.
+                if (File.Exists(tempFile))
+                    File.Delete(tempFile);
+            }
+        }
+
+        /// <summary>
+        /// Awaits <paramref name="task"/>; when <paramref name="enable"/> is true a progress
+        /// indicator is written to the console while the task runs, followed by "Done." once it
+        /// has completed successfully.
+        /// </summary>
+        /// <param name="task">The task to wait for.</param>
+        /// <param name="enable">Whether to show the progress indicator.</param>
+        /// <returns>A task that completes after <paramref name="task"/>.</returns>
+        public static async Task ShowProgressInConsole(this Task task, bool enable)
+        {
+            if (!enable)
+            {
+                await task;
+                return;
+            }
+
+            var cts = new CancellationTokenSource();
+
+            var showProgressTask = ShowProgressInConsole(cts);
+
+            try
+            {
+                await task;
+            }
+            finally
+            {
+                // Stop the indicator whether the task succeeded or failed.
+                cts.Cancel();
+            }
+
+            await showProgressTask;
+            Console.WriteLine("Done.");
+        }
+
+        private static async Task ShowProgressInConsole(CancellationTokenSource cts)
+        {
+            var cols = 0;
+
await Task.Delay(100); + + while (!cts.IsCancellationRequested) + { + await Task.Delay(100); + Console.Write("."); + cols++; + + if (cols % 50 == 0) + { + Console.WriteLine(); + } + } + + if (cols > 0) + Console.WriteLine(); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs b/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs new file mode 100644 index 00000000..26dfd3b6 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Hub/MnistModelLoaderTest.cs @@ -0,0 +1,24 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System.Threading.Tasks; +using Tensorflow.Hub; + +namespace UnitTest +{ + [TestClass] + public class MnistModelLoaderTest + { + [TestMethod] + public async Task TestLoad() + { + var loader = new MnistModelLoader(); + var result = await loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = true, + ValidationSize = 5000, + }); + + Assert.IsNotNull(result); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/UnitTest.csproj b/test/TensorFlowNET.UnitTest/UnitTest.csproj index 861522f8..cff48481 100644 --- a/test/TensorFlowNET.UnitTest/UnitTest.csproj +++ b/test/TensorFlowNET.UnitTest/UnitTest.csproj @@ -37,6 +37,7 @@ +