| @@ -20,7 +20,7 @@ namespace TensorFlowNET.Examples | |||||
| // Parameters | // Parameters | ||||
| float learning_rate = 0.01f; | float learning_rate = 0.01f; | ||||
| int training_epochs = 1000; | |||||
| public int TrainingEpochs = 1000; | |||||
| int display_step = 50; | int display_step = 50; | ||||
| NDArray train_X, train_Y; | NDArray train_X, train_Y; | ||||
| @@ -62,7 +62,7 @@ namespace TensorFlowNET.Examples | |||||
| sess.run(init); | sess.run(init); | ||||
| // Fit all training data | // Fit all training data | ||||
| for (int epoch = 0; epoch < training_epochs; epoch++) | |||||
| for (int epoch = 0; epoch < TrainingEpochs; epoch++) | |||||
| { | { | ||||
| foreach (var (x, y) in zip<float>(train_X, train_Y)) | foreach (var (x, y) in zip<float>(train_X, train_Y)) | ||||
| { | { | ||||
| @@ -22,8 +22,9 @@ namespace TensorFlowNET.Examples | |||||
| private float learning_rate = 0.01f; | private float learning_rate = 0.01f; | ||||
| public int TrainingEpochs = 10; | public int TrainingEpochs = 10; | ||||
| public int DataSize = 5000; | |||||
| public int TestSize = 5000; | |||||
| public int? TrainSize = null; | |||||
| public int ValidationSize = 5000; | |||||
| public int? TestSize = null; | |||||
| public int BatchSize = 100; | public int BatchSize = 100; | ||||
| private int display_step = 1; | private int display_step = 1; | ||||
| @@ -98,7 +99,7 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size: TestSize); | |||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size: ValidationSize, test_size: TestSize); | |||||
| } | } | ||||
| public void SaveModel(Session sess) | public void SaveModel(Session sess) | ||||
| @@ -141,7 +142,7 @@ namespace TensorFlowNET.Examples | |||||
| if (results.argmax() == (batch_ys[0] as NDArray).argmax()) | if (results.argmax() == (batch_ys[0] as NDArray).argmax()) | ||||
| print("predicted OK!"); | print("predicted OK!"); | ||||
| else | else | ||||
| throw new ValueError("predict error, maybe 90% accuracy"); | |||||
| throw new ValueError("predict error, should be 90% accuracy"); | |||||
| }); | }); | ||||
| } | } | ||||
| } | } | ||||
| @@ -19,8 +19,9 @@ namespace TensorFlowNET.Examples | |||||
| public string Name => "Nearest Neighbor"; | public string Name => "Nearest Neighbor"; | ||||
| Datasets mnist; | Datasets mnist; | ||||
| NDArray Xtr, Ytr, Xte, Yte; | NDArray Xtr, Ytr, Xte, Yte; | ||||
| public int DataSize = 5000; | |||||
| public int TestBatchSize = 200; | |||||
| public int? TrainSize = null; | |||||
| public int ValidationSize = 5000; | |||||
| public int? TestSize = null; | |||||
| public bool Run() | public bool Run() | ||||
| { | { | ||||
| @@ -64,10 +65,10 @@ namespace TensorFlowNET.Examples | |||||
| public void PrepareData() | public void PrepareData() | ||||
| { | { | ||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize); | |||||
| mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize); | |||||
| // In this example, we limit mnist data | // In this example, we limit mnist data | ||||
| (Xtr, Ytr) = mnist.train.next_batch(DataSize); // 5000 for training (nn candidates) | |||||
| (Xte, Yte) = mnist.test.next_batch(TestBatchSize); // 200 for testing | |||||
| (Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) | |||||
| (Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -15,16 +15,17 @@ namespace TensorFlowNET.Examples.Utility | |||||
| private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; | private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; | ||||
| private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; | ||||
| private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; | ||||
| public static Datasets read_data_sets(string train_dir, | public static Datasets read_data_sets(string train_dir, | ||||
| bool one_hot = false, | bool one_hot = false, | ||||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | TF_DataType dtype = TF_DataType.TF_FLOAT, | ||||
| bool reshape = true, | bool reshape = true, | ||||
| int validation_size = 5000, | int validation_size = 5000, | ||||
| int test_size = 5000, | |||||
| int? train_size = null, | |||||
| int? test_size = null, | |||||
| string source_url = DEFAULT_SOURCE_URL) | string source_url = DEFAULT_SOURCE_URL) | ||||
| { | { | ||||
| var train_size = validation_size * 2; | |||||
| if (train_size!=null && validation_size >= train_size) | |||||
| throw new ArgumentException("Validation set should be smaller than training set"); | |||||
| Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES); | Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES); | ||||
| Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir); | Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir); | ||||
| @@ -51,7 +51,7 @@ namespace TensorFlowNET.UnitTest.ExamplesTests | |||||
| [TestMethod] | [TestMethod] | ||||
| public void LogisticRegression() | public void LogisticRegression() | ||||
| { | { | ||||
| new LogisticRegression() { Enabled = true, TrainingEpochs=10, DataSize = 500, TestSize = 500 }.Run(); | |||||
| new LogisticRegression() { Enabled = true, TrainingEpochs=10, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run(); | |||||
| } | } | ||||
| [Ignore] | [Ignore] | ||||
| @@ -78,7 +78,7 @@ namespace TensorFlowNET.UnitTest.ExamplesTests | |||||
| [TestMethod] | [TestMethod] | ||||
| public void NearestNeighbor() | public void NearestNeighbor() | ||||
| { | { | ||||
| new NearestNeighbor() { Enabled = true, DataSize = 500, TestBatchSize = 100 }.Run(); | |||||
| new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run(); | |||||
| } | } | ||||
| [Ignore] | [Ignore] | ||||