
TestSuite: added all examples with very small training sets (runs through within seconds)

Meinrad Recheis · 6 years ago · commit 30dde0fde9 · tags/v0.9
7 changed files with 40 additions and 26 deletions:

  1. +4  -1   test/TensorFlowNET.Examples/KMeansClustering.cs
  2. +8  -6   test/TensorFlowNET.Examples/LogisticRegression.cs
  3. +5  -3   test/TensorFlowNET.Examples/NearestNeighbor.cs
  4. +6  -5   test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs
  5. +2  -1   test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs
  6. +12 -7   test/TensorFlowNET.Examples/Utility/MnistDataSet.cs
  7. +3  -3   test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

+4 -1  test/TensorFlowNET.Examples/KMeansClustering.cs

@@ -18,6 +18,9 @@ namespace TensorFlowNET.Examples
         public int Priority => 8;
         public bool Enabled { get; set; } = true;
         public string Name => "K-means Clustering";
+        public int DataSize = 5000;
+        public int TestSize = 5000;
+        public int BatchSize = 100;

         Datasets mnist;
         NDArray full_data_x;

@@ -45,7 +48,7 @@ namespace TensorFlowNET.Examples

         public void PrepareData()
         {
-            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
+            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size: TestSize);
             full_data_x = mnist.train.images;
         }
     }
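
The three new fields replace what were effectively hard-coded constants, so a test can shrink the run the same way the updated tests at the bottom of this commit do. A minimal sketch (hypothetical: this commit does not touch the K-means unit test itself):

    // Hypothetical fast-test configuration; 500 is an arbitrary small size,
    // mirroring the LogisticRegression/NearestNeighbor tests in this commit.
    new KMeansClustering() { Enabled = true, DataSize = 500, TestSize = 500, BatchSize = 100 }.Run();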


+8 -6  test/TensorFlowNET.Examples/LogisticRegression.cs

@@ -21,8 +21,10 @@ namespace TensorFlowNET.Examples
         public string Name => "Logistic Regression";

         private float learning_rate = 0.01f;
-        private int training_epochs = 10;
-        private int batch_size = 100;
+        public int TrainingEpochs = 10;
+        public int DataSize = 5000;
+        public int TestSize = 5000;
+        public int BatchSize = 100;
         private int display_step = 1;

         Datasets mnist;

@@ -57,14 +59,14 @@ namespace TensorFlowNET.Examples
             sess.run(init);

             // Training cycle
-            foreach (var epoch in range(training_epochs))
+            foreach (var epoch in range(TrainingEpochs))
             {
                 var avg_cost = 0.0f;
-                var total_batch = mnist.train.num_examples / batch_size;
+                var total_batch = mnist.train.num_examples / BatchSize;
                 // Loop over all batches
                 foreach (var i in range(total_batch))
                 {
-                    var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
+                    var (batch_xs, batch_ys) = mnist.train.next_batch(BatchSize);
                     // Run optimization op (backprop) and cost op (to get loss value)
                     var result = sess.run(new object[] { optimizer, cost },
                         new FeedItem(x, batch_xs),

@@ -96,7 +98,7 @@ namespace TensorFlowNET.Examples

         public void PrepareData()
         {
-            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
+            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size: TestSize);
         }

         public void SaveModel(Session sess)
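
With the fields public, the unit test can dial everything down: with DataSize = 500 and BatchSize = 100, total_batch = 500 / 100 = 5 batches per epoch, or 50 optimizer steps across the 10 epochs, assuming the loader leaves DataSize rows in the train split (see MnistDataSet.cs below):

    // Configuration used by the updated unit test in this commit:
    new LogisticRegression() { Enabled = true, TrainingEpochs = 10, DataSize = 500, TestSize = 500 }.Run();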


+5 -3  test/TensorFlowNET.Examples/NearestNeighbor.cs

@@ -19,6 +19,8 @@ namespace TensorFlowNET.Examples
         public string Name => "Nearest Neighbor";
         Datasets mnist;
         NDArray Xtr, Ytr, Xte, Yte;
+        public int DataSize = 5000;
+        public int TestBatchSize = 200;

         public bool Run()
         {

@@ -62,10 +64,10 @@ namespace TensorFlowNET.Examples

         public void PrepareData()
         {
-            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
+            mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize);
             // In this example, we limit mnist data
-            (Xtr, Ytr) = mnist.train.next_batch(5000); // 5000 for training (nn candidates)
-            (Xte, Yte) = mnist.test.next_batch(200); // 200 for testing
+            (Xtr, Ytr) = mnist.train.next_batch(DataSize); // 5000 for training (nn candidates)
+            (Xte, Yte) = mnist.test.next_batch(TestBatchSize); // 200 for testing
         }
     }
 }
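
Note how the knobs line up with the loader change below: read_data_sets keeps validation_size * 2 training rows, hands validation_size of them to the validation split, and leaves DataSize rows for next_batch(DataSize) to consume, so the whole reduced train split becomes the neighbor-candidate set. The updated unit test exercises it as:

    // From the updated unit test: 500 neighbor candidates, 100 test queries.
    new NearestNeighbor() { Enabled = true, DataSize = 500, TestBatchSize = 100 }.Run();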

+6 -5  test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs

@@ -13,7 +13,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
         private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
         private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";

-        public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len)
+        public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len, int? limit = null)
         {
             string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} ";
             /*if (step == "train")

@@ -25,10 +25,11 @@ namespace TensorFlowNET.Examples.CnnTextClassification
                 char_dict[c.ToString()] = char_dict.Count;

             var contents = File.ReadAllLines(TRAIN_PATH);
-            var x = new int[contents.Length][];
-            var y = new int[contents.Length];
-            for (int i = 0; i < contents.Length; i++)
+            var size = limit == null ? contents.Length : limit.Value;
+
+            var x = new int[size][];
+            var y = new int[size];
+            for (int i = 0; i < size; i++)
             {
                 string[] parts = contents[i].ToLower().Split(",\"").ToArray();
                 string content = parts[2];
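
One caveat worth flagging: unlike the MNIST readers below, this limit is not clamped against the file length, so a limit larger than the CSV would make contents[i] throw an IndexOutOfRangeException. A defensive variant (a suggestion, not what the commit ships):

    // Suggested clamp so an oversized limit cannot index past the file:
    var size = limit == null ? contents.Length : Math.Min(limit.Value, contents.Length);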


+2 -1  test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs

@@ -17,6 +17,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
         public int Priority => 100;
         public bool Enabled { get; set; } = false;
         public string Name => "Text Classification";
+        public int? DataLimit = null;

         private string dataDir = "text_classification";
         private string dataFileName = "dbpedia_csv.tar.gz";

@@ -28,7 +29,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
         {
             PrepareData();
             Console.WriteLine("Building dataset...");
-            var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);
+            var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN, DataLimit);

             var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
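
Because the new field is a nullable int defaulting to null, callers that never set it keep the old full-dataset behavior, while the updated unit test caps the build at 100 rows:

    new TextClassificationTrain() { Enabled = true, DataLimit = 100 }.Run(); // capped, as in the updated test
    new TextClassificationTrain() { Enabled = true }.Run();                  // DataLimit = null => full CSV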




+12 -7  test/TensorFlowNET.Examples/Utility/MnistDataSet.cs

@@ -21,23 +21,26 @@ namespace TensorFlowNET.Examples.Utility
             TF_DataType dtype = TF_DataType.TF_FLOAT,
             bool reshape = true,
             int validation_size = 5000,
+            int test_size = 5000,
             string source_url = DEFAULT_SOURCE_URL)
         {
+            var train_size = validation_size * 2;
             Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
             Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);
-            var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]));
+            var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]), limit: train_size);

             Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS);
             Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir);
-            var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot);
+            var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot, limit: train_size);

             Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES);
             Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir);
-            var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]));
+            var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]), limit: test_size);

             Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS);
             Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir);
-            var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot);
+            var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot, limit: test_size);

             int end = train_images.shape[0];
             var validation_images = train_images[np.arange(validation_size)];

@@ -52,14 +55,15 @@ namespace TensorFlowNET.Examples.Utility
             return new Datasets(train, validation, test);
         }

-        public static NDArray extract_images(string file)
+        public static NDArray extract_images(string file, int? limit = null)
         {
             using (var bytestream = new FileStream(file, FileMode.Open))
             {
                 var magic = _read32(bytestream);
                 if (magic != 2051)
                     throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
                 var num_images = _read32(bytestream);
+                num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit);
                 var rows = _read32(bytestream);
                 var cols = _read32(bytestream);
                 var buf = new byte[rows * cols * num_images];

@@ -70,7 +74,7 @@ namespace TensorFlowNET.Examples.Utility
             }
         }

-        public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10)
+        public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10, int? limit = null)
         {
             using (var bytestream = new FileStream(file, FileMode.Open))
             {

@@ -78,6 +82,7 @@ namespace TensorFlowNET.Examples.Utility
                 if (magic != 2049)
                     throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
                 var num_items = _read32(bytestream);
+                num_items = limit == null ? num_items : Math.Min(num_items, (uint)limit);
                 var buf = new byte[num_items];
                 bytestream.Read(buf, 0, buf.Length);
                 var labels = np.frombuffer(buf, np.uint8);
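
Both readers clamp with Math.Min rather than trusting the caller, since the IDX header's count is a uint and the requested limit may exceed what the file holds. A worked example of the arithmetic, assuming the tests' validation_size = 500:

    // validation_size = 500 => train_size = 1000: extract_images reads
    // min(60000, 1000) = 1000 images; the first 500 become the validation
    // split, leaving 500 rows for the train split.
    uint num_images = 60000;  // count from the IDX header
    int? limit = 1000;        // train_size passed in by read_data_sets
    num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit);  // 1000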


+3 -3  test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

@@ -51,7 +51,7 @@ namespace TensorFlowNET.UnitTest.ExamplesTests
         [TestMethod]
         public void LogisticRegression()
         {
-            new LogisticRegression() { Enabled = true }.Run();
+            new LogisticRegression() { Enabled = true, TrainingEpochs = 10, DataSize = 500, TestSize = 500 }.Run();
         }

         [Ignore]

@@ -78,14 +78,14 @@ namespace TensorFlowNET.UnitTest.ExamplesTests
         [TestMethod]
         public void NearestNeighbor()
         {
-            new NearestNeighbor() { Enabled = true }.Run();
+            new NearestNeighbor() { Enabled = true, DataSize = 500, TestBatchSize = 100 }.Run();
         }

         [Ignore]
         [TestMethod]
         public void TextClassificationTrain()
         {
-            new TextClassificationTrain() { Enabled = true }.Run();
+            new TextClassificationTrain() { Enabled = true, DataLimit = 100 }.Run();
         }

         [Ignore]

