diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index 15ebfc54..83764f40 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -20,7 +20,7 @@ namespace TensorFlowNET.Examples // Parameters float learning_rate = 0.01f; - int training_epochs = 1000; + public int TrainingEpochs = 1000; int display_step = 50; NDArray train_X, train_Y; @@ -62,7 +62,7 @@ namespace TensorFlowNET.Examples sess.run(init); // Fit all training data - for (int epoch = 0; epoch < training_epochs; epoch++) + for (int epoch = 0; epoch < TrainingEpochs; epoch++) { foreach (var (x, y) in zip(train_X, train_Y)) { diff --git a/test/TensorFlowNET.Examples/LogisticRegression.cs b/test/TensorFlowNET.Examples/LogisticRegression.cs index 89aaa27a..944df714 100644 --- a/test/TensorFlowNET.Examples/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/LogisticRegression.cs @@ -22,8 +22,9 @@ namespace TensorFlowNET.Examples private float learning_rate = 0.01f; public int TrainingEpochs = 10; - public int DataSize = 5000; - public int TestSize = 5000; + public int? TrainSize = null; + public int ValidationSize = 5000; + public int? TestSize = null; public int BatchSize = 100; private int display_step = 1; @@ -98,7 +99,7 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size: TestSize); + mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size: ValidationSize, test_size: TestSize); } public void SaveModel(Session sess) @@ -141,7 +142,7 @@ namespace TensorFlowNET.Examples if (results.argmax() == (batch_ys[0] as NDArray).argmax()) print("predicted OK!"); else - throw new ValueError("predict error, maybe 90% accuracy"); + throw new ValueError("predict error, should be 90% accuracy"); }); } } diff --git a/test/TensorFlowNET.Examples/NearestNeighbor.cs b/test/TensorFlowNET.Examples/NearestNeighbor.cs index 6894009a..dd5624f1 100644 --- a/test/TensorFlowNET.Examples/NearestNeighbor.cs +++ b/test/TensorFlowNET.Examples/NearestNeighbor.cs @@ -19,8 +19,9 @@ namespace TensorFlowNET.Examples public string Name => "Nearest Neighbor"; Datasets mnist; NDArray Xtr, Ytr, Xte, Yte; - public int DataSize = 5000; - public int TestBatchSize = 200; + public int? TrainSize = null; + public int ValidationSize = 5000; + public int? TestSize = null; public bool Run() { @@ -64,10 +65,10 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize); + mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize); // In this example, we limit mnist data - (Xtr, Ytr) = mnist.train.next_batch(DataSize); // 5000 for training (nn candidates) - (Xte, Yte) = mnist.test.next_batch(TestBatchSize); // 200 for testing + (Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) + (Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing } } } diff --git a/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs b/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs index 7616449c..d4ee0824 100644 --- a/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs +++ b/test/TensorFlowNET.Examples/Utility/MnistDataSet.cs @@ -15,16 +15,17 @@ namespace TensorFlowNET.Examples.Utility private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; - public static Datasets read_data_sets(string train_dir, bool one_hot = false, TF_DataType dtype = TF_DataType.TF_FLOAT, bool reshape = true, int validation_size = 5000, - int test_size = 5000, + int? train_size = null, + int? test_size = null, string source_url = DEFAULT_SOURCE_URL) { - var train_size = validation_size * 2; + if (train_size!=null && validation_size >= train_size) + throw new ArgumentException("Validation set should be smaller than training set"); Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES); Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir); diff --git a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs index e884708c..5c74320f 100644 --- a/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs +++ b/test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs @@ -51,7 +51,7 @@ namespace TensorFlowNET.UnitTest.ExamplesTests [TestMethod] public void LogisticRegression() { - new LogisticRegression() { Enabled = true, TrainingEpochs=10, DataSize = 500, TestSize = 500 }.Run(); + new LogisticRegression() { Enabled = true, TrainingEpochs=10, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run(); } [Ignore] @@ -78,7 +78,7 @@ namespace TensorFlowNET.UnitTest.ExamplesTests [TestMethod] public void NearestNeighbor() { - new NearestNeighbor() { Enabled = true, DataSize = 500, TestBatchSize = 100 }.Run(); + new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run(); } [Ignore]