Browse Source

fix: Examples project uses all data, unit test uses only small fraction

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
1ff31db07e
5 changed files with 19 additions and 16 deletions
  1. +2
    -2
      test/TensorFlowNET.Examples/LinearRegression.cs
  2. +5
    -4
      test/TensorFlowNET.Examples/LogisticRegression.cs
  3. +6
    -5
      test/TensorFlowNET.Examples/NearestNeighbor.cs
  4. +4
    -3
      test/TensorFlowNET.Examples/Utility/MnistDataSet.cs
  5. +2
    -2
      test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

+ 2
- 2
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -20,7 +20,7 @@ namespace TensorFlowNET.Examples


// Parameters // Parameters
float learning_rate = 0.01f; float learning_rate = 0.01f;
int training_epochs = 1000;
public int TrainingEpochs = 1000;
int display_step = 50; int display_step = 50;


NDArray train_X, train_Y; NDArray train_X, train_Y;
@@ -62,7 +62,7 @@ namespace TensorFlowNET.Examples
sess.run(init); sess.run(init);


// Fit all training data // Fit all training data
for (int epoch = 0; epoch < training_epochs; epoch++)
for (int epoch = 0; epoch < TrainingEpochs; epoch++)
{ {
foreach (var (x, y) in zip<float>(train_X, train_Y)) foreach (var (x, y) in zip<float>(train_X, train_Y))
{ {


+ 5
- 4
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -22,8 +22,9 @@ namespace TensorFlowNET.Examples


private float learning_rate = 0.01f; private float learning_rate = 0.01f;
public int TrainingEpochs = 10; public int TrainingEpochs = 10;
public int DataSize = 5000;
public int TestSize = 5000;
public int? TrainSize = null;
public int ValidationSize = 5000;
public int? TestSize = null;
public int BatchSize = 100; public int BatchSize = 100;
private int display_step = 1; private int display_step = 1;


@@ -98,7 +99,7 @@ namespace TensorFlowNET.Examples


public void PrepareData() public void PrepareData()
{ {
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size: TestSize);
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size: ValidationSize, test_size: TestSize);
} }


public void SaveModel(Session sess) public void SaveModel(Session sess)
@@ -141,7 +142,7 @@ namespace TensorFlowNET.Examples
if (results.argmax() == (batch_ys[0] as NDArray).argmax()) if (results.argmax() == (batch_ys[0] as NDArray).argmax())
print("predicted OK!"); print("predicted OK!");
else else
throw new ValueError("predict error, maybe 90% accuracy");
throw new ValueError("predict error, should be 90% accuracy");
}); });
} }
} }


+ 6
- 5
test/TensorFlowNET.Examples/NearestNeighbor.cs View File

@@ -19,8 +19,9 @@ namespace TensorFlowNET.Examples
public string Name => "Nearest Neighbor"; public string Name => "Nearest Neighbor";
Datasets mnist; Datasets mnist;
NDArray Xtr, Ytr, Xte, Yte; NDArray Xtr, Ytr, Xte, Yte;
public int DataSize = 5000;
public int TestBatchSize = 200;
public int? TrainSize = null;
public int ValidationSize = 5000;
public int? TestSize = null;


public bool Run() public bool Run()
{ {
@@ -64,10 +65,10 @@ namespace TensorFlowNET.Examples


public void PrepareData() public void PrepareData()
{ {
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize);
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize);
// In this example, we limit mnist data // In this example, we limit mnist data
(Xtr, Ytr) = mnist.train.next_batch(DataSize); // 5000 for training (nn candidates)
(Xte, Yte) = mnist.test.next_batch(TestBatchSize); // 200 for testing
(Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
(Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing
} }
} }
} }

+ 4
- 3
test/TensorFlowNET.Examples/Utility/MnistDataSet.cs View File

@@ -15,16 +15,17 @@ namespace TensorFlowNET.Examples.Utility
private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

public static Datasets read_data_sets(string train_dir, public static Datasets read_data_sets(string train_dir,
bool one_hot = false, bool one_hot = false,
TF_DataType dtype = TF_DataType.TF_FLOAT, TF_DataType dtype = TF_DataType.TF_FLOAT,
bool reshape = true, bool reshape = true,
int validation_size = 5000, int validation_size = 5000,
int test_size = 5000,
int? train_size = null,
int? test_size = null,
string source_url = DEFAULT_SOURCE_URL) string source_url = DEFAULT_SOURCE_URL)
{ {
var train_size = validation_size * 2;
if (train_size!=null && validation_size >= train_size)
throw new ArgumentException("Validation set should be smaller than training set");
Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES); Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir); Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);


+ 2
- 2
test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs View File

@@ -51,7 +51,7 @@ namespace TensorFlowNET.UnitTest.ExamplesTests
[TestMethod] [TestMethod]
public void LogisticRegression() public void LogisticRegression()
{ {
new LogisticRegression() { Enabled = true, TrainingEpochs=10, DataSize = 500, TestSize = 500 }.Run();
new LogisticRegression() { Enabled = true, TrainingEpochs=10, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
} }
[Ignore] [Ignore]
@@ -78,7 +78,7 @@ namespace TensorFlowNET.UnitTest.ExamplesTests
[TestMethod] [TestMethod]
public void NearestNeighbor() public void NearestNeighbor()
{ {
new NearestNeighbor() { Enabled = true, DataSize = 500, TestBatchSize = 100 }.Run();
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
} }
[Ignore] [Ignore]


Loading…
Cancel
Save