Browse Source

fix: Examples project uses all data, unit test uses only small fraction

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
1ff31db07e
5 changed files with 19 additions and 16 deletions
  1. +2
    -2
      test/TensorFlowNET.Examples/LinearRegression.cs
  2. +5
    -4
      test/TensorFlowNET.Examples/LogisticRegression.cs
  3. +6
    -5
      test/TensorFlowNET.Examples/NearestNeighbor.cs
  4. +4
    -3
      test/TensorFlowNET.Examples/Utility/MnistDataSet.cs
  5. +2
    -2
      test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

+ 2
- 2
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -20,7 +20,7 @@ namespace TensorFlowNET.Examples

// Parameters
float learning_rate = 0.01f;
int training_epochs = 1000;
public int TrainingEpochs = 1000;
int display_step = 50;

NDArray train_X, train_Y;
@@ -62,7 +62,7 @@ namespace TensorFlowNET.Examples
sess.run(init);

// Fit all training data
for (int epoch = 0; epoch < training_epochs; epoch++)
for (int epoch = 0; epoch < TrainingEpochs; epoch++)
{
foreach (var (x, y) in zip<float>(train_X, train_Y))
{


+ 5
- 4
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -22,8 +22,9 @@ namespace TensorFlowNET.Examples

private float learning_rate = 0.01f;
public int TrainingEpochs = 10;
public int DataSize = 5000;
public int TestSize = 5000;
public int? TrainSize = null;
public int ValidationSize = 5000;
public int? TestSize = null;
public int BatchSize = 100;
private int display_step = 1;

@@ -98,7 +99,7 @@ namespace TensorFlowNET.Examples

public void PrepareData()
{
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size: TestSize);
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size: ValidationSize, test_size: TestSize);
}

public void SaveModel(Session sess)
@@ -141,7 +142,7 @@ namespace TensorFlowNET.Examples
if (results.argmax() == (batch_ys[0] as NDArray).argmax())
print("predicted OK!");
else
throw new ValueError("predict error, maybe 90% accuracy");
throw new ValueError("predict error, should be 90% accuracy");
});
}
}


+ 6
- 5
test/TensorFlowNET.Examples/NearestNeighbor.cs View File

@@ -19,8 +19,9 @@ namespace TensorFlowNET.Examples
public string Name => "Nearest Neighbor";
Datasets mnist;
NDArray Xtr, Ytr, Xte, Yte;
public int DataSize = 5000;
public int TestBatchSize = 200;
public int? TrainSize = null;
public int ValidationSize = 5000;
public int? TestSize = null;

public bool Run()
{
@@ -64,10 +65,10 @@ namespace TensorFlowNET.Examples

public void PrepareData()
{
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize);
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize);
// In this example, we limit mnist data
(Xtr, Ytr) = mnist.train.next_batch(DataSize); // 5000 for training (nn candidates)
(Xte, Yte) = mnist.test.next_batch(TestBatchSize); // 200 for testing
(Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
(Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing
}
}
}

+ 4
- 3
test/TensorFlowNET.Examples/Utility/MnistDataSet.cs View File

@@ -15,16 +15,17 @@ namespace TensorFlowNET.Examples.Utility
private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";

public static Datasets read_data_sets(string train_dir,
bool one_hot = false,
TF_DataType dtype = TF_DataType.TF_FLOAT,
bool reshape = true,
int validation_size = 5000,
int test_size = 5000,
int? train_size = null,
int? test_size = null,
string source_url = DEFAULT_SOURCE_URL)
{
var train_size = validation_size * 2;
if (train_size!=null && validation_size >= train_size)
throw new ArgumentException("Validation set should be smaller than training set");
Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);


+ 2
- 2
test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs View File

@@ -51,7 +51,7 @@ namespace TensorFlowNET.UnitTest.ExamplesTests
[TestMethod]
public void LogisticRegression()
{
new LogisticRegression() { Enabled = true, TrainingEpochs=10, DataSize = 500, TestSize = 500 }.Run();
new LogisticRegression() { Enabled = true, TrainingEpochs=10, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
}
[Ignore]
@@ -78,7 +78,7 @@ namespace TensorFlowNET.UnitTest.ExamplesTests
[TestMethod]
public void NearestNeighbor()
{
new NearestNeighbor() { Enabled = true, DataSize = 500, TestBatchSize = 100 }.Run();
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
}
[Ignore]


Loading…
Cancel
Save