| @@ -11,12 +11,20 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "UnitTest", "test\TensorFlow | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras", "src\TensorFlowNET.Keras\Tensorflow.Keras.csproj", "{6268B461-486A-460B-9B3C-86493CBBAAF7}" | |||
| EndProject | |||
| Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}" | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Keras.UnitTest", "test\Tensorflow.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj", "{EB92DD90-6346-41FB-B967-2B33A860AD98}" | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tensorflow.Hub", "src\TensorFlowNET.Hub\Tensorflow.Hub.csproj", "{95B077C1-E21B-486F-8BDD-1C902FE687AB}" | |||
| EndProject | |||
| Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}" | |||
| EndProject | |||
| Global | |||
| GlobalSection(SolutionConfigurationPlatforms) = preSolution | |||
| Debug|Any CPU = Debug|Any CPU | |||
| Debug|x64 = Debug|x64 | |||
| Debug-Minimal|Any CPU = Debug-Minimal|Any CPU | |||
| Debug-Minimal|x64 = Debug-Minimal|x64 | |||
| Publish|Any CPU = Publish|Any CPU | |||
| Publish|x64 = Publish|x64 | |||
| Release|Any CPU = Release|Any CPU | |||
| Release|x64 = Release|x64 | |||
| EndGlobalSection | |||
| @@ -25,6 +33,14 @@ Global | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.ActiveCfg = Release|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Publish|x64.Build.0 = Release|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|x64.ActiveCfg = Release|Any CPU | |||
| @@ -33,6 +49,14 @@ Global | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.ActiveCfg = Release|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Publish|x64.Build.0 = Release|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {3A6EB896-604F-4E25-B677-B8103BCF3D2E}.Release|x64.ActiveCfg = Release|Any CPU | |||
| @@ -41,6 +65,14 @@ Global | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.ActiveCfg = Release|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Publish|x64.Build.0 = Release|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {23C28035-2FCE-41F3-9A12-E73CE8A5AE32}.Release|x64.ActiveCfg = Release|Any CPU | |||
| @@ -49,6 +81,14 @@ Global | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.ActiveCfg = Release|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Publish|x64.Build.0 = Release|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {6268B461-486A-460B-9B3C-86493CBBAAF7}.Release|x64.ActiveCfg = Release|Any CPU | |||
| @@ -57,10 +97,50 @@ Global | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.ActiveCfg = Release|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Publish|x64.Build.0 = Release|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.ActiveCfg = Release|Any CPU | |||
| {EB92DD90-6346-41FB-B967-2B33A860AD98}.Release|x64.Build.0 = Release|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|Any CPU.Build.0 = Debug|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|x64.ActiveCfg = Debug|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Publish|x64.Build.0 = Debug|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|x64.ActiveCfg = Release|Any CPU | |||
| {95B077C1-E21B-486F-8BDD-1C902FE687AB}.Release|x64.Build.0 = Release|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|Any CPU.Build.0 = Debug|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|x64.ActiveCfg = Debug|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug|x64.Build.0 = Debug|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|Any CPU.ActiveCfg = Debug|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|Any CPU.Build.0 = Debug|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|x64.ActiveCfg = Debug|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Debug-Minimal|x64.Build.0 = Debug|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|Any CPU.ActiveCfg = Release|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|Any CPU.Build.0 = Release|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|x64.ActiveCfg = Release|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Publish|x64.Build.0 = Release|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|Any CPU.ActiveCfg = Release|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|Any CPU.Build.0 = Release|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|x64.ActiveCfg = Release|Any CPU | |||
| {9A9E06BA-9E3E-4FA1-9A8F-9DB596A1A5CB}.Release|x64.Build.0 = Release|Any CPU | |||
| EndGlobalSection | |||
| GlobalSection(SolutionProperties) = preSolution | |||
| HideSolutionNode = FALSE | |||
| @@ -88,7 +88,7 @@ namespace Tensorflow | |||
| case ICollection arr: | |||
| return arr.Count; | |||
| case NDArray ndArray: | |||
| return ndArray.shape[0]; | |||
| return ndArray.ndim == 0 ? 1 : ndArray.shape[0]; | |||
| case IEnumerable enumerable: | |||
| return enumerable.OfType<object>().Count(); | |||
| } | |||
| @@ -60,10 +60,13 @@ https://tensorflownet.readthedocs.io</Description> | |||
| <ItemGroup> | |||
| <PackageReference Include="Google.Protobuf" Version="3.11.2" /> | |||
| <PackageReference Include="NumSharp" Version="0.20.5" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <Folder Include="Keras\Initializers\" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -0,0 +1,13 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using NumSharp; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| public abstract class DataSetBase : IDataSet | |||
| { | |||
| public NDArray Data { get; protected set; } | |||
| public NDArray Labels { get; protected set; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,46 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using NumSharp; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| public class Datasets<TDataSet> where TDataSet : IDataSet | |||
| { | |||
| public TDataSet Train { get; private set; } | |||
| public TDataSet Validation { get; private set; } | |||
| public TDataSet Test { get; private set; } | |||
| public Datasets(TDataSet train, TDataSet validation, TDataSet test) | |||
| { | |||
| Train = train; | |||
| Validation = validation; | |||
| Test = test; | |||
| } | |||
| public (NDArray, NDArray) Randomize(NDArray x, NDArray y) | |||
| { | |||
| var perm = np.random.permutation(y.shape[0]); | |||
| np.random.shuffle(perm); | |||
| return (x[perm], y[perm]); | |||
| } | |||
| /// <summary> | |||
| /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) | |||
| /// </summary> | |||
| /// <param name="x"></param> | |||
| /// <param name="y"></param> | |||
| /// <param name="start"></param> | |||
| /// <param name="end"></param> | |||
| /// <returns></returns> | |||
| public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) | |||
| { | |||
| var slice = new Slice(start, end); | |||
| var x_batch = x[slice]; | |||
| var y_batch = y[slice]; | |||
| return (x_batch, y_batch); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,13 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using NumSharp; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| public interface IDataSet | |||
| { | |||
| NDArray Data { get; } | |||
| NDArray Labels { get; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,14 @@ | |||
| using System; | |||
| using System.Threading.Tasks; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using NumSharp; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| public interface IModelLoader<TDataSet> | |||
| where TDataSet : IDataSet | |||
| { | |||
| Task<Datasets<TDataSet>> LoadAsync(ModelLoadSetting setting); | |||
| } | |||
| } | |||
| @@ -0,0 +1,88 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Text; | |||
| using NumSharp; | |||
| using Tensorflow; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| public class MnistDataSet : DataSetBase | |||
| { | |||
| public int NumOfExamples { get; private set; } | |||
| public int EpochsCompleted { get; private set; } | |||
| public int IndexInEpoch { get; private set; } | |||
| public MnistDataSet(NDArray images, NDArray labels, Type dataType, bool reshape) | |||
| { | |||
| EpochsCompleted = 0; | |||
| IndexInEpoch = 0; | |||
| NumOfExamples = images.shape[0]; | |||
| images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | |||
| images = images.astype(dataType); | |||
| // for debug np.multiply performance | |||
| var sw = new Stopwatch(); | |||
| sw.Start(); | |||
| images = np.multiply(images, 1.0f / 255.0f); | |||
| sw.Stop(); | |||
| Console.WriteLine($"{sw.ElapsedMilliseconds}ms"); | |||
| Data = images; | |||
| labels = labels.astype(dataType); | |||
| Labels = labels; | |||
| } | |||
| public (NDArray, NDArray) GetNextBatch(int batch_size, bool fake_data = false, bool shuffle = true) | |||
| { | |||
| if (IndexInEpoch >= NumOfExamples) | |||
| IndexInEpoch = 0; | |||
| var start = IndexInEpoch; | |||
| // Shuffle for the first epoch | |||
| if(EpochsCompleted == 0 && start == 0 && shuffle) | |||
| { | |||
| var perm0 = np.arange(NumOfExamples); | |||
| np.random.shuffle(perm0); | |||
| Data = Data[perm0]; | |||
| Labels = Labels[perm0]; | |||
| } | |||
| // Go to the next epoch | |||
| if (start + batch_size > NumOfExamples) | |||
| { | |||
| // Finished epoch | |||
| EpochsCompleted += 1; | |||
| // Get the rest examples in this epoch | |||
| var rest_num_examples = NumOfExamples - start; | |||
| var images_rest_part = Data[np.arange(start, NumOfExamples)]; | |||
| var labels_rest_part = Labels[np.arange(start, NumOfExamples)]; | |||
| // Shuffle the data | |||
| if (shuffle) | |||
| { | |||
| var perm = np.arange(NumOfExamples); | |||
| np.random.shuffle(perm); | |||
| Data = Data[perm]; | |||
| Labels = Labels[perm]; | |||
| } | |||
| start = 0; | |||
| IndexInEpoch = batch_size - rest_num_examples; | |||
| var end = IndexInEpoch; | |||
| var images_new_part = Data[np.arange(start, end)]; | |||
| var labels_new_part = Labels[np.arange(start, end)]; | |||
| return (np.concatenate(new[] { images_rest_part, images_new_part }, axis: 0), | |||
| np.concatenate(new[] { labels_rest_part, labels_new_part }, axis: 0)); | |||
| } | |||
| else | |||
| { | |||
| IndexInEpoch += batch_size; | |||
| var end = IndexInEpoch; | |||
| return (Data[np.arange(start, end)], Labels[np.arange(start, end)]); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,184 @@ | |||
| using System; | |||
| using System.Threading.Tasks; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using System.IO; | |||
| using NumSharp; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| public class MnistModelLoader : IModelLoader<MnistDataSet> | |||
| { | |||
| private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"; | |||
| private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz"; | |||
| private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; | |||
| private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | |||
| private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | |||
| public static async Task<Datasets<MnistDataSet>> LoadAsync(string trainDir, bool oneHot = false, int? trainSize = null, int? validationSize = null, int? testSize = null, bool showProgressInConsole = false) | |||
| { | |||
| var loader = new MnistModelLoader(); | |||
| var setting = new ModelLoadSetting | |||
| { | |||
| TrainDir = trainDir, | |||
| OneHot = oneHot, | |||
| ShowProgressInConsole = showProgressInConsole | |||
| }; | |||
| if (trainSize.HasValue) | |||
| setting.TrainSize = trainSize.Value; | |||
| if (validationSize.HasValue) | |||
| setting.ValidationSize = validationSize.Value; | |||
| if (testSize.HasValue) | |||
| setting.TestSize = testSize.Value; | |||
| return await loader.LoadAsync(setting); | |||
| } | |||
| public async Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting) | |||
| { | |||
| if (setting.TrainSize.HasValue && setting.ValidationSize >= setting.TrainSize.Value) | |||
| throw new ArgumentException("Validation set should be smaller than training set"); | |||
| var sourceUrl = setting.SourceUrl; | |||
| if (string.IsNullOrEmpty(sourceUrl)) | |||
| sourceUrl = DEFAULT_SOURCE_URL; | |||
| // load train images | |||
| await this.DownloadAsync(sourceUrl + TRAIN_IMAGES, setting.TrainDir, TRAIN_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| var trainImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_IMAGES)), limit: setting.TrainSize); | |||
| // load train labels | |||
| await this.DownloadAsync(sourceUrl + TRAIN_LABELS, setting.TrainDir, TRAIN_LABELS, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TRAIN_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| var trainLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TRAIN_LABELS)), one_hot: setting.OneHot, limit: setting.TrainSize); | |||
| // load test images | |||
| await this.DownloadAsync(sourceUrl + TEST_IMAGES, setting.TrainDir, TEST_IMAGES, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_IMAGES), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| var testImages = ExtractImages(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_IMAGES)), limit: setting.TestSize); | |||
| // load test labels | |||
| await this.DownloadAsync(sourceUrl + TEST_LABELS, setting.TrainDir, TEST_LABELS, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| await this.UnzipAsync(Path.Combine(setting.TrainDir, TEST_LABELS), setting.TrainDir, showProgressInConsole: setting.ShowProgressInConsole) | |||
| .ShowProgressInConsole(setting.ShowProgressInConsole); | |||
| var testLabels = ExtractLabels(Path.Combine(setting.TrainDir, Path.GetFileNameWithoutExtension(TEST_LABELS)), one_hot: setting.OneHot, limit: setting.TestSize); | |||
| var end = trainImages.shape[0]; | |||
| var validationSize = setting.ValidationSize; | |||
| var validationImages = trainImages[np.arange(validationSize)]; | |||
| var validationLabels = trainLabels[np.arange(validationSize)]; | |||
| trainImages = trainImages[np.arange(validationSize, end)]; | |||
| trainLabels = trainLabels[np.arange(validationSize, end)]; | |||
| var dtype = setting.DataType; | |||
| var reshape = setting.ReShape; | |||
| var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape); | |||
| var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape); | |||
| var test = new MnistDataSet(testImages, testLabels, dtype, reshape); | |||
| return new Datasets<MnistDataSet>(train, validation, test); | |||
| } | |||
| private NDArray ExtractImages(string file, int? limit = null) | |||
| { | |||
| if (!Path.IsPathRooted(file)) | |||
| file = Path.Combine(AppContext.BaseDirectory, file); | |||
| using (var bytestream = new FileStream(file, FileMode.Open)) | |||
| { | |||
| var magic = Read32(bytestream); | |||
| if (magic != 2051) | |||
| throw new Exception($"Invalid magic number {magic} in MNIST image file: {file}"); | |||
| var num_images = Read32(bytestream); | |||
| num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit); | |||
| var rows = Read32(bytestream); | |||
| var cols = Read32(bytestream); | |||
| var buf = new byte[rows * cols * num_images]; | |||
| bytestream.Read(buf, 0, buf.Length); | |||
| var data = np.frombuffer(buf, np.uint8); | |||
| data = data.reshape((int)num_images, (int)rows, (int)cols, 1); | |||
| return data; | |||
| } | |||
| } | |||
| private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes = 10, int? limit = null) | |||
| { | |||
| if (!Path.IsPathRooted(file)) | |||
| file = Path.Combine(AppContext.BaseDirectory, file); | |||
| using (var bytestream = new FileStream(file, FileMode.Open)) | |||
| { | |||
| var magic = Read32(bytestream); | |||
| if (magic != 2049) | |||
| throw new Exception($"Invalid magic number {magic} in MNIST label file: {file}"); | |||
| var num_items = Read32(bytestream); | |||
| num_items = limit == null ? num_items : Math.Min(num_items, (uint)limit); | |||
| var buf = new byte[num_items]; | |||
| bytestream.Read(buf, 0, buf.Length); | |||
| var labels = np.frombuffer(buf, np.uint8); | |||
| if (one_hot) | |||
| return DenseToOneHot(labels, num_classes); | |||
| return labels; | |||
| } | |||
| } | |||
| private NDArray DenseToOneHot(NDArray labels_dense, int num_classes) | |||
| { | |||
| var num_labels = labels_dense.shape[0]; | |||
| var index_offset = np.arange(num_labels) * num_classes; | |||
| var labels_one_hot = np.zeros(num_labels, num_classes); | |||
| var labels = labels_dense.Data<byte>(); | |||
| for (int row = 0; row < num_labels; row++) | |||
| { | |||
| var col = labels[row]; | |||
| labels_one_hot.SetData(1.0, row, col); | |||
| } | |||
| return labels_one_hot; | |||
| } | |||
| private uint Read32(FileStream bytestream) | |||
| { | |||
| var buffer = new byte[sizeof(uint)]; | |||
| var count = bytestream.Read(buffer, 0, 4); | |||
| return np.frombuffer(buffer, ">u4").Data<uint>()[0]; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,20 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using NumSharp; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| public class ModelLoadSetting | |||
| { | |||
| public string TrainDir { get; set; } | |||
| public bool OneHot { get; set; } | |||
| public Type DataType { get; set; } = typeof(float); | |||
| public bool ReShape { get; set; } | |||
| public int ValidationSize { get; set; } = 5000; | |||
| public int? TrainSize { get; set; } | |||
| public int? TestSize { get; set; } | |||
| public string SourceUrl { get; set; } | |||
| public bool ShowProgressInConsole { get; set; } | |||
| } | |||
| } | |||
| @@ -0,0 +1,5 @@ | |||
| ## TensorFlow Hub | |||
| TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models. In particular, it provides **modules**, which are pre-trained pieces of TensorFlow models that can be reused on new tasks. | |||
| https://github.com/tensorflow/hub | |||
| @@ -0,0 +1,26 @@ | |||
| <Project Sdk="Microsoft.NET.Sdk"> | |||
| <PropertyGroup> | |||
| <RootNamespace>Tensorflow.Hub</RootNamespace> | |||
| <TargetFramework>netstandard2.0</TargetFramework> | |||
| <Version>0.0.5</Version> | |||
| <Authors>Kerry Jiang, Haiping Chen</Authors> | |||
| <Company>SciSharp STACK</Company> | |||
| <Copyright>Apache 2.0</Copyright> | |||
| <RepositoryUrl>https://github.com/SciSharp/TensorFlow.NET</RepositoryUrl> | |||
| <RepositoryType>git</RepositoryType> | |||
| <PackageProjectUrl>http://scisharpstack.org</PackageProjectUrl> | |||
| <PackageTags>TensorFlow, SciSharp, MachineLearning</PackageTags> | |||
| <Description>TensorFlow Hub is a library to foster the publication, discovery, and consumption of reusable parts of machine learning models.</Description> | |||
| <PackageId>SciSharp.TensorFlowHub</PackageId> | |||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | |||
| <PackageReleaseNotes>Fix GetNextBatch() bug.</PackageReleaseNotes> | |||
| <PackageIconUrl>https://avatars3.githubusercontent.com/u/44989469?s=200&v=4</PackageIconUrl> | |||
| <AssemblyName>TensorFlow.Hub</AssemblyName> | |||
| </PropertyGroup> | |||
| <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'"> | |||
| <DefineConstants>DEBUG;TRACE</DefineConstants> | |||
| </PropertyGroup> | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||
| </ItemGroup> | |||
| </Project> | |||
| @@ -0,0 +1,137 @@ | |||
| using System; | |||
| using System.IO; | |||
| using System.IO.Compression; | |||
| using System.Collections.Generic; | |||
| using System.Net; | |||
| using System.Text; | |||
| using System.Threading; | |||
| using System.Threading.Tasks; | |||
| namespace Tensorflow.Hub | |||
| { | |||
| public static class Utils | |||
| { | |||
| public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string saveTo) | |||
| where TDataSet : IDataSet | |||
| { | |||
| var dir = Path.GetDirectoryName(saveTo); | |||
| var fileName = Path.GetFileName(saveTo); | |||
| await modelLoader.DownloadAsync(url, dir, fileName); | |||
| } | |||
| public static async Task DownloadAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string url, string dirSaveTo, string fileName, bool showProgressInConsole = false) | |||
| where TDataSet : IDataSet | |||
| { | |||
| if (!Path.IsPathRooted(dirSaveTo)) | |||
| dirSaveTo = Path.Combine(AppContext.BaseDirectory, dirSaveTo); | |||
| var fileSaveTo = Path.Combine(dirSaveTo, fileName); | |||
| if (showProgressInConsole) | |||
| { | |||
| Console.WriteLine($"Downloading {fileName}"); | |||
| } | |||
| if (File.Exists(fileSaveTo)) | |||
| { | |||
| if (showProgressInConsole) | |||
| { | |||
| Console.WriteLine($"The file {fileName} already exists"); | |||
| } | |||
| return; | |||
| } | |||
| Directory.CreateDirectory(dirSaveTo); | |||
| using (var wc = new WebClient()) | |||
| { | |||
| await wc.DownloadFileTaskAsync(url, fileSaveTo).ConfigureAwait(false); | |||
| } | |||
| } | |||
| public static async Task UnzipAsync<TDataSet>(this IModelLoader<TDataSet> modelLoader, string zipFile, string saveTo, bool showProgressInConsole = false) | |||
| where TDataSet : IDataSet | |||
| { | |||
| if (!Path.IsPathRooted(saveTo)) | |||
| saveTo = Path.Combine(AppContext.BaseDirectory, saveTo); | |||
| Directory.CreateDirectory(saveTo); | |||
| if (!Path.IsPathRooted(zipFile)) | |||
| zipFile = Path.Combine(AppContext.BaseDirectory, zipFile); | |||
| var destFileName = Path.GetFileNameWithoutExtension(zipFile); | |||
| var destFilePath = Path.Combine(saveTo, destFileName); | |||
| if (showProgressInConsole) | |||
| Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}"); | |||
| if (File.Exists(destFilePath)) | |||
| { | |||
| if (showProgressInConsole) | |||
| Console.WriteLine($"The file {destFileName} already exists"); | |||
| } | |||
| using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) | |||
| { | |||
| using (var destStream = File.Create(destFilePath)) | |||
| { | |||
| await unzipStream.CopyToAsync(destStream).ConfigureAwait(false); | |||
| await destStream.FlushAsync().ConfigureAwait(false); | |||
| destStream.Close(); | |||
| } | |||
| unzipStream.Close(); | |||
| } | |||
| } | |||
| public static async Task ShowProgressInConsole(this Task task, bool enable) | |||
| { | |||
| if (!enable) | |||
| { | |||
| await task; | |||
| return; | |||
| } | |||
| var cts = new CancellationTokenSource(); | |||
| var showProgressTask = ShowProgressInConsole(cts); | |||
| try | |||
| { | |||
| await task; | |||
| } | |||
| finally | |||
| { | |||
| cts.Cancel(); | |||
| } | |||
| await showProgressTask; | |||
| Console.WriteLine("Done."); | |||
| } | |||
| private static async Task ShowProgressInConsole(CancellationTokenSource cts) | |||
| { | |||
| var cols = 0; | |||
| await Task.Delay(100); | |||
| while (!cts.IsCancellationRequested) | |||
| { | |||
| await Task.Delay(100); | |||
| Console.Write("."); | |||
| cols++; | |||
| if (cols % 50 == 0) | |||
| { | |||
| Console.WriteLine(); | |||
| } | |||
| } | |||
| if (cols > 0) | |||
| Console.WriteLine(); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,24 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System.Threading.Tasks; | |||
| using Tensorflow.Hub; | |||
| namespace UnitTest | |||
| { | |||
| [TestClass] | |||
| public class MnistModelLoaderTest | |||
| { | |||
| [TestMethod] | |||
| public async Task TestLoad() | |||
| { | |||
| var loader = new MnistModelLoader(); | |||
| var result = await loader.LoadAsync(new ModelLoadSetting | |||
| { | |||
| TrainDir = "mnist", | |||
| OneHot = true, | |||
| ValidationSize = 5000, | |||
| }); | |||
| Assert.IsNotNull(result); | |||
| } | |||
| } | |||
| } | |||
| @@ -37,6 +37,7 @@ | |||
| <ItemGroup> | |||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" /> | |||
| <ProjectReference Include="..\..\src\TensorFlowNET.Hub\Tensorflow.Hub.csproj" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||