Browse Source

add train and predict interfaces to IExample.

tags/v0.9
Oceania2018 6 years ago
parent
commit
c42f8bbf88
25 changed files with 503 additions and 87 deletions
  1. +21
    -2
      test/TensorFlowNET.Examples/BasicEagerApi.cs
  2. +21
    -2
      test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs
  3. +21
    -3
      test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs
  4. +21
    -2
      test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs
  5. +22
    -3
      test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs
  6. +21
    -2
      test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs
  7. +23
    -4
      test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs
  8. +21
    -3
      test/TensorFlowNET.Examples/BasicOperations.cs
  9. +21
    -3
      test/TensorFlowNET.Examples/HelloWorld.cs
  10. +13
    -8
      test/TensorFlowNET.Examples/IExample.cs
  11. +21
    -3
      test/TensorFlowNET.Examples/ImageProcess/ImageBackgroundRemoval.cs
  12. +21
    -2
      test/TensorFlowNET.Examples/ImageProcess/ImageRecognitionInception.cs
  13. +21
    -2
      test/TensorFlowNET.Examples/ImageProcess/InceptionArchGoogLeNet.cs
  14. +21
    -2
      test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs
  15. +21
    -1
      test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs
  16. +4
    -5
      test/TensorFlowNET.Examples/Program.cs
  17. +21
    -2
      test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs
  18. +23
    -4
      test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
  19. +21
    -3
      test/TensorFlowNET.Examples/TextProcess/NER/BiLstmCrfNer.cs
  20. +21
    -3
      test/TensorFlowNET.Examples/TextProcess/NER/CRF.cs
  21. +21
    -3
      test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs
  22. +22
    -3
      test/TensorFlowNET.Examples/TextProcess/NamedEntityRecognition.cs
  23. +24
    -5
      test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
  24. +21
    -2
      test/TensorFlowNET.Examples/TextProcess/Word2Vec.cs
  25. +15
    -15
      test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs

+ 21
- 2
test/TensorFlowNET.Examples/BasicEagerApi.cs View File

@@ -11,10 +11,9 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class BasicEagerApi : IExample public class BasicEagerApi : IExample
{ {
public int Priority => 100;
public bool Enabled { get; set; } = false; public bool Enabled { get; set; } = false;
public string Name => "Basic Eager"; public string Name => "Basic Eager";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;


private Tensor a, b, c, d; private Tensor a, b, c, d;


@@ -46,5 +45,25 @@ namespace TensorFlowNET.Examples
public void PrepareData() public void PrepareData()
{ {
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 2
test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs View File

@@ -18,10 +18,9 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class KMeansClustering : IExample public class KMeansClustering : IExample
{ {
public int Priority => 8;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "K-means Clustering"; public string Name => "K-means Clustering";
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;


public int? train_size = null; public int? train_size = null;
public int validation_size = 5000; public int validation_size = 5000;
@@ -127,5 +126,25 @@ namespace TensorFlowNET.Examples
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta"; string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta";
Web.Download(url, "graph", "kmeans.meta"); Web.Download(url, "graph", "kmeans.meta");
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 3
test/TensorFlowNET.Examples/BasicModels/LinearRegression.cs View File

@@ -13,11 +13,9 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class LinearRegression : IExample public class LinearRegression : IExample
{ {
public int Priority => 3;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Linear Regression"; public string Name => "Linear Regression";
public bool ImportGraph { get; set; } = false;

public bool IsImportingGraph { get; set; } = false;


public int training_epochs = 1000; public int training_epochs = 1000;


@@ -113,5 +111,25 @@ namespace TensorFlowNET.Examples
2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);
n_samples = train_X.shape[0]; n_samples = train_X.shape[0];
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 2
test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs View File

@@ -18,10 +18,9 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class LogisticRegression : IExample public class LogisticRegression : IExample
{ {
public int Priority => 4;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Logistic Regression"; public string Name => "Logistic Regression";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;




public int training_epochs = 10; public int training_epochs = 10;
@@ -158,5 +157,25 @@ namespace TensorFlowNET.Examples
throw new ValueError("predict error, should be 90% accuracy"); throw new ValueError("predict error, should be 90% accuracy");
}); });
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

bool IExample.Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 22
- 3
test/TensorFlowNET.Examples/BasicModels/NaiveBayesClassifier.cs View File

@@ -13,10 +13,9 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class NaiveBayesClassifier : IExample public class NaiveBayesClassifier : IExample
{ {
public int Priority => 6;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Naive Bayes Classifier"; public string Name => "Naive Bayes Classifier";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;


public NDArray X, y; public NDArray X, y;
public Normal dist { get; set; } public Normal dist { get; set; }
@@ -96,7 +95,7 @@ namespace TensorFlowNET.Examples
this.dist = dist; this.dist = dist;
} }


public Tensor predict (NDArray X)
public Tensor predict(NDArray X)
{ {
if (dist == null) if (dist == null)
{ {
@@ -170,5 +169,25 @@ namespace TensorFlowNET.Examples
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2); 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2);
#endregion #endregion
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 2
test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs View File

@@ -15,7 +15,6 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class NearestNeighbor : IExample public class NearestNeighbor : IExample
{ {
public int Priority => 5;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Nearest Neighbor"; public string Name => "Nearest Neighbor";
Datasets mnist; Datasets mnist;
@@ -23,7 +22,7 @@ namespace TensorFlowNET.Examples
public int? TrainSize = null; public int? TrainSize = null;
public int ValidationSize = 5000; public int ValidationSize = 5000;
public int? TestSize = null; public int? TestSize = null;
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;




public bool Run() public bool Run()
@@ -76,5 +75,25 @@ namespace TensorFlowNET.Examples
(Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) (Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
(Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing (Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 23
- 4
test/TensorFlowNET.Examples/BasicModels/NeuralNetXor.cs View File

@@ -14,10 +14,9 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class NeuralNetXor : IExample public class NeuralNetXor : IExample
{ {
public int Priority => 10;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "NN XOR"; public string Name => "NN XOR";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;


public int num_steps = 10000; public int num_steps = 10000;


@@ -54,7 +53,7 @@ namespace TensorFlowNET.Examples
{ {
PrepareData(); PrepareData();
float loss_value = 0; float loss_value = 0;
if (ImportGraph)
if (IsImportingGraph)
loss_value = RunWithImportedGraph(); loss_value = RunWithImportedGraph();
else else
loss_value = RunWithBuiltGraph(); loss_value = RunWithBuiltGraph();
@@ -145,12 +144,32 @@ namespace TensorFlowNET.Examples
{0, 1 } {0, 1 }
}; };


if (ImportGraph)
if (IsImportingGraph)
{ {
// download graph meta data // download graph meta data
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/xor.meta"; string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/xor.meta";
Web.Download(url, "graph", "xor.meta"); Web.Download(url, "graph", "xor.meta");
} }
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 3
test/TensorFlowNET.Examples/BasicOperations.cs View File

@@ -14,10 +14,8 @@ namespace TensorFlowNET.Examples
public class BasicOperations : IExample public class BasicOperations : IExample
{ {
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public int Priority => 2;
public string Name => "Basic Operations"; public string Name => "Basic Operations";
public bool ImportGraph { get; set; } = false;

public bool IsImportingGraph { get; set; } = false;


private Session sess; private Session sess;


@@ -104,5 +102,25 @@ namespace TensorFlowNET.Examples
public void PrepareData() public void PrepareData()
{ {
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 3
test/TensorFlowNET.Examples/HelloWorld.cs View File

@@ -12,11 +12,9 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class HelloWorld : IExample public class HelloWorld : IExample
{ {
public int Priority => 1;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Hello World"; public string Name => "Hello World";
public bool ImportGraph { get; set; } = false;

public bool IsImportingGraph { get; set; } = false;


public bool Run() public bool Run()
{ {
@@ -41,5 +39,25 @@ namespace TensorFlowNET.Examples
public void PrepareData() public void PrepareData()
{ {
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 13
- 8
test/TensorFlowNET.Examples/IExample.cs View File

@@ -1,7 +1,8 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;

using Tensorflow;
namespace TensorFlowNET.Examples namespace TensorFlowNET.Examples
{ {
/// <summary> /// <summary>
@@ -10,11 +11,6 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public interface IExample public interface IExample
{ {
/// <summary>
/// running order
/// </summary>
int Priority { get; }

/// <summary> /// <summary>
/// True to run example /// True to run example
/// </summary> /// </summary>
@@ -23,15 +19,24 @@ namespace TensorFlowNET.Examples
/// <summary> /// <summary>
/// Set true to import the computation graph instead of building it. /// Set true to import the computation graph instead of building it.
/// </summary> /// </summary>
bool ImportGraph { get; set; }
bool IsImportingGraph { get; set; }


string Name { get; } string Name { get; }


bool Run();

/// <summary> /// <summary>
/// Build dataflow graph, train and predict /// Build dataflow graph, train and predict
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
bool Run();
bool Train();

bool Predict();

Graph ImportGraph();

Graph BuildGraph();

/// <summary> /// <summary>
/// Prepare dataset /// Prepare dataset
/// </summary> /// </summary>


+ 21
- 3
test/TensorFlowNET.Examples/ImageProcess/ImageBackgroundRemoval.cs View File

@@ -15,10 +15,8 @@ namespace TensorFlowNET.Examples.ImageProcess
/// </summary> /// </summary>
public class ImageBackgroundRemoval : IExample public class ImageBackgroundRemoval : IExample
{ {
public int Priority => 15;

public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;


public string Name => "Image Background Removal"; public string Name => "Image Background Removal";


@@ -59,5 +57,25 @@ namespace TensorFlowNET.Examples.ImageProcess
Web.Download(url, modelDir, fileName); Web.Download(url, modelDir, fileName);
Compress.ExtractTGZ(Path.Join(modelDir, fileName), modelDir);*/ Compress.ExtractTGZ(Path.Join(modelDir, fileName), modelDir);*/
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 2
test/TensorFlowNET.Examples/ImageProcess/ImageRecognitionInception.cs View File

@@ -20,10 +20,9 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class ImageRecognitionInception : IExample public class ImageRecognitionInception : IExample
{ {
public int Priority => 7;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Image Recognition Inception"; public string Name => "Image Recognition Inception";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;




string dir = "ImageRecognitionInception"; string dir = "ImageRecognitionInception";
@@ -115,5 +114,25 @@ namespace TensorFlowNET.Examples
file_ndarrays.Add(nd); file_ndarrays.Add(nd);
} }
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 2
test/TensorFlowNET.Examples/ImageProcess/InceptionArchGoogLeNet.cs View File

@@ -21,9 +21,8 @@ namespace TensorFlowNET.Examples
public class InceptionArchGoogLeNet : IExample public class InceptionArchGoogLeNet : IExample
{ {
public bool Enabled { get; set; } = false; public bool Enabled { get; set; } = false;
public int Priority => 100;
public string Name => "Inception Arch GoogLeNet"; public string Name => "Inception Arch GoogLeNet";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;




string dir = "label_image_data"; string dir = "label_image_data";
@@ -108,5 +107,25 @@ namespace TensorFlowNET.Examples
url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{pic}"; url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{pic}";
Utility.Web.Download(url, dir, pic); Utility.Web.Download(url, dir, pic);
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 2
test/TensorFlowNET.Examples/ImageProcess/ObjectDetection.cs View File

@@ -16,10 +16,9 @@ namespace TensorFlowNET.Examples


public class ObjectDetection : IExample public class ObjectDetection : IExample
{ {
public int Priority => 11;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Object Detection"; public string Name => "Object Detection";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;


public float MIN_SCORE = 0.5f; public float MIN_SCORE = 0.5f;


@@ -145,5 +144,25 @@ namespace TensorFlowNET.Examples
} }
} }
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 1
test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs View File

@@ -26,7 +26,7 @@ namespace TensorFlowNET.Examples.ImageProcess
public int Priority => 16; public int Priority => 16;


public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;


public string Name => "Retrain Image Classifier"; public string Name => "Retrain Image Classifier";


@@ -667,5 +667,25 @@ namespace TensorFlowNET.Examples.ImageProcess


return result; return result;
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 4
- 5
test/TensorFlowNET.Examples/Program.cs View File

@@ -19,7 +19,6 @@ namespace TensorFlowNET.Examples
var examples = Assembly.GetEntryAssembly().GetTypes() var examples = Assembly.GetEntryAssembly().GetTypes()
.Where(x => x.GetInterfaces().Contains(typeof(IExample))) .Where(x => x.GetInterfaces().Contains(typeof(IExample)))
.Select(x => (IExample)Activator.CreateInstance(x)) .Select(x => (IExample)Activator.CreateInstance(x))
.OrderBy(x => x.Priority)
.ToArray(); .ToArray();


Console.WriteLine($"TensorFlow v{tf.VERSION}", Color.Yellow); Console.WriteLine($"TensorFlow v{tf.VERSION}", Color.Yellow);
@@ -42,18 +41,18 @@ namespace TensorFlowNET.Examples
sw.Stop(); sw.Stop();


if (isSuccess) if (isSuccess)
success.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s");
success.Add($"Example: {example.Name} in {sw.Elapsed.TotalSeconds}s");
else else
errors.Add($"Example {example.Priority}: {example.Name} in {sw.Elapsed.TotalSeconds}s");
errors.Add($"Example: {example.Name} in {sw.Elapsed.TotalSeconds}s");
} }
else else
{ {
disabled.Add($"Example {example.Priority}: {example.Name} in {sw.ElapsedMilliseconds}ms");
disabled.Add($"Example: {example.Name} in {sw.ElapsedMilliseconds}ms");
} }
} }
catch (Exception ex) catch (Exception ex)
{ {
errors.Add($"Example {example.Priority}: {example.Name}");
errors.Add($"Example: {example.Name}");
Console.WriteLine(ex); Console.WriteLine(ex);
} }


+ 21
- 2
test/TensorFlowNET.Examples/TextProcess/BinaryTextClassification.cs View File

@@ -17,10 +17,9 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class BinaryTextClassification : IExample public class BinaryTextClassification : IExample
{ {
public int Priority => 9;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Binary Text Classification"; public string Name => "Binary Text Classification";
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;


string dir = "binary_text_classification"; string dir = "binary_text_classification";
string dataFile = "imdb.zip"; string dataFile = "imdb.zip";
@@ -138,5 +137,25 @@ namespace TensorFlowNET.Examples


return result; return result;
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 23
- 4
test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs View File

@@ -21,11 +21,10 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class CnnTextClassification : IExample public class CnnTextClassification : IExample
{ {
public int Priority => 17;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "CNN Text Classification"; public string Name => "CNN Text Classification";
public int? DataLimit = null; public int? DataLimit = null;
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = false;


private string dataDir = "word_cnn"; private string dataDir = "word_cnn";
private string dataFileName = "dbpedia_csv.tar.gz"; private string dataFileName = "dbpedia_csv.tar.gz";
@@ -49,7 +48,7 @@ namespace TensorFlowNET.Examples
var graph = tf.Graph().as_default(); var graph = tf.Graph().as_default();
return with(tf.Session(graph), sess => return with(tf.Session(graph), sess =>
{ {
if (ImportGraph)
if (IsImportingGraph)
return RunWithImportedGraph(sess, graph); return RunWithImportedGraph(sess, graph);
else else
return RunWithBuiltGraph(sess, graph); return RunWithBuiltGraph(sess, graph);
@@ -222,7 +221,7 @@ namespace TensorFlowNET.Examples
Web.Download(url, dataDir, "dbpedia_subset.zip"); Web.Download(url, dataDir, "dbpedia_subset.zip");
Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv")); Compress.UnZip(Path.Combine(dataDir, "dbpedia_subset.zip"), Path.Combine(dataDir, "dbpedia_csv"));


if (ImportGraph)
if (IsImportingGraph)
{ {
// download graph meta data // download graph meta data
var meta_file = "word_cnn.meta"; var meta_file = "word_cnn.meta";
@@ -237,5 +236,25 @@ namespace TensorFlowNET.Examples
Web.Download(url, "graph", meta_file); Web.Download(url, "graph", meta_file);
} }
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 3
test/TensorFlowNET.Examples/TextProcess/NER/BiLstmCrfNer.cs View File

@@ -14,10 +14,8 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class BiLstmCrfNer : IExample public class BiLstmCrfNer : IExample
{ {
public int Priority => 101;

public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;


public string Name => "bi-LSTM + CRF NER"; public string Name => "bi-LSTM + CRF NER";


@@ -35,5 +33,25 @@ namespace TensorFlowNET.Examples
hp.filepath_tags = Path.Combine(hp.data_root_dir, "vocab.tags.txt"); hp.filepath_tags = Path.Combine(hp.data_root_dir, "vocab.tags.txt");
hp.filepath_glove = Path.Combine(hp.data_root_dir, "glove.npz"); hp.filepath_glove = Path.Combine(hp.data_root_dir, "glove.npz");
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 3
test/TensorFlowNET.Examples/TextProcess/NER/CRF.cs View File

@@ -15,10 +15,8 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class CRF : IExample public class CRF : IExample
{ {
public int Priority => 13;

public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;


public string Name => "CRF"; public string Name => "CRF";


@@ -31,5 +29,25 @@ namespace TensorFlowNET.Examples
{ {


} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 3
test/TensorFlowNET.Examples/TextProcess/NER/LstmCrfNer.cs View File

@@ -20,10 +20,8 @@ namespace TensorFlowNET.Examples.Text.NER
/// </summary> /// </summary>
public class LstmCrfNer : IExample public class LstmCrfNer : IExample
{ {
public int Priority => 14;

public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;


public string Name => "LSTM + CRF NER"; public string Name => "LSTM + CRF NER";


@@ -208,5 +206,25 @@ namespace TensorFlowNET.Examples.Text.NER
Web.Download(url, "graph", meta_file); Web.Download(url, "graph", meta_file);


} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 22
- 3
test/TensorFlowNET.Examples/TextProcess/NamedEntityRecognition.cs View File

@@ -11,13 +11,12 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class NamedEntityRecognition : IExample public class NamedEntityRecognition : IExample
{ {
public int Priority => 100;
public bool Enabled { get; set; } = false; public bool Enabled { get; set; } = false;
public string Name => "NER"; public string Name => "NER";
public bool ImportGraph { get; set; } = false;
public bool IsImportingGraph { get; set; } = false;




public bool Run()
public bool Train()
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }
@@ -26,5 +25,25 @@ namespace TensorFlowNET.Examples
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }

public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Run()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 24
- 5
test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs View File

@@ -21,11 +21,10 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class TextClassificationTrain : IExample public class TextClassificationTrain : IExample
{ {
public int Priority => 100;
public bool Enabled { get; set; } = false; public bool Enabled { get; set; } = false;
public string Name => "Text Classification"; public string Name => "Text Classification";
public int? DataLimit = null; public int? DataLimit = null;
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;
public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia public bool UseSubset = false; // <----- set this true to use a limited subset of dbpedia


private string dataDir = "text_classification"; private string dataDir = "text_classification";
@@ -51,7 +50,7 @@ namespace TensorFlowNET.Examples
var graph = tf.Graph().as_default(); var graph = tf.Graph().as_default();
return with(tf.Session(graph), sess => return with(tf.Session(graph), sess =>
{ {
if (ImportGraph)
if (IsImportingGraph)
return RunWithImportedGraph(sess, graph); return RunWithImportedGraph(sess, graph);
else else
return RunWithBuiltGraph(sess, graph); return RunWithBuiltGraph(sess, graph);
@@ -255,7 +254,7 @@ namespace TensorFlowNET.Examples
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir); Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
} }


if (ImportGraph)
if (IsImportingGraph)
{ {
// download graph meta data // download graph meta data
var meta_file = model_name + ".meta"; var meta_file = model_name + ".meta";
@@ -269,6 +268,26 @@ namespace TensorFlowNET.Examples
var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; var url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
Web.Download(url, "graph", meta_file); Web.Download(url, "graph", meta_file);
} }
}
}
public Graph ImportGraph()
{
throw new NotImplementedException();
}
public Graph BuildGraph()
{
throw new NotImplementedException();
}
public bool Train()
{
throw new NotImplementedException();
}
public bool Predict()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 2
test/TensorFlowNET.Examples/TextProcess/Word2Vec.cs View File

@@ -16,10 +16,9 @@ namespace TensorFlowNET.Examples
/// </summary> /// </summary>
public class Word2Vec : IExample public class Word2Vec : IExample
{ {
public int Priority => 12;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Word2Vec"; public string Name => "Word2Vec";
public bool ImportGraph { get; set; } = true;
public bool IsImportingGraph { get; set; } = true;


// Training Parameters // Training Parameters
float learning_rate = 0.1f; float learning_rate = 0.1f;
@@ -205,6 +204,26 @@ namespace TensorFlowNET.Examples
print($"Most common words: {string.Join(", ", word2id.Take(10))}"); print($"Most common words: {string.Join(", ", word2id.Take(10))}");
} }


public Graph ImportGraph()
{
throw new NotImplementedException();
}

public Graph BuildGraph()
{
throw new NotImplementedException();
}

public bool Train()
{
throw new NotImplementedException();
}

public bool Predict()
{
throw new NotImplementedException();
}

private class WordId private class WordId
{ {
public string Word { get; set; } public string Word { get; set; }


+ 15
- 15
test/TensorFlowNET.UnitTest/ExamplesTests/ExamplesTest.cs View File

@@ -14,21 +14,21 @@ namespace TensorFlowNET.ExamplesTests
public void BasicOperations() public void BasicOperations()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new BasicOperations() { Enabled = true }.Run();
new BasicOperations() { Enabled = true }.Train();
} }
[TestMethod] [TestMethod]
public void HelloWorld() public void HelloWorld()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new HelloWorld() { Enabled = true }.Run();
new HelloWorld() { Enabled = true }.Train();
} }
[TestMethod] [TestMethod]
public void ImageRecognition() public void ImageRecognition()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new HelloWorld() { Enabled = true }.Run();
new HelloWorld() { Enabled = true }.Train();
} }
[Ignore] [Ignore]
@@ -36,28 +36,28 @@ namespace TensorFlowNET.ExamplesTests
public void InceptionArchGoogLeNet() public void InceptionArchGoogLeNet()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new InceptionArchGoogLeNet() { Enabled = true }.Run();
new InceptionArchGoogLeNet() { Enabled = true }.Train();
} }
[TestMethod] [TestMethod]
public void KMeansClustering() public void KMeansClustering()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new KMeansClustering() { Enabled = true, ImportGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Run();
new KMeansClustering() { Enabled = true, IsImportingGraph = true, train_size = 500, validation_size = 100, test_size = 100, batch_size =100 }.Train();
} }
[TestMethod] [TestMethod]
public void LinearRegression() public void LinearRegression()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new LinearRegression() { Enabled = true }.Run();
new LinearRegression() { Enabled = true }.Train();
} }
[TestMethod] [TestMethod]
public void LogisticRegression() public void LogisticRegression()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Run();
new LogisticRegression() { Enabled = true, training_epochs=10, train_size = 500, validation_size = 100, test_size = 100 }.Train();
} }
[Ignore] [Ignore]
@@ -65,7 +65,7 @@ namespace TensorFlowNET.ExamplesTests
public void NaiveBayesClassifier() public void NaiveBayesClassifier()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new NaiveBayesClassifier() { Enabled = false }.Run();
new NaiveBayesClassifier() { Enabled = false }.Train();
} }
[Ignore] [Ignore]
@@ -73,14 +73,14 @@ namespace TensorFlowNET.ExamplesTests
public void NamedEntityRecognition() public void NamedEntityRecognition()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new NamedEntityRecognition() { Enabled = true }.Run();
new NamedEntityRecognition() { Enabled = true }.Train();
} }
[TestMethod] [TestMethod]
public void NearestNeighbor() public void NearestNeighbor()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Run();
new NearestNeighbor() { Enabled = true, TrainSize = 500, ValidationSize = 100, TestSize = 100 }.Train();
} }
[Ignore] [Ignore]
@@ -88,7 +88,7 @@ namespace TensorFlowNET.ExamplesTests
public void TextClassificationTrain() public void TextClassificationTrain()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Run();
new TextClassificationTrain() { Enabled = true, DataLimit=100 }.Train();
} }
[Ignore] [Ignore]
@@ -96,21 +96,21 @@ namespace TensorFlowNET.ExamplesTests
public void TextClassificationWithMovieReviews() public void TextClassificationWithMovieReviews()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
new BinaryTextClassification() { Enabled = true }.Run();
new BinaryTextClassification() { Enabled = true }.Train();
} }
[TestMethod] [TestMethod]
public void NeuralNetXor() public void NeuralNetXor()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
Assert.IsTrue(new NeuralNetXor() { Enabled = true, ImportGraph = false }.Run());
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = false }.Train());
} }
[TestMethod] [TestMethod]
public void NeuralNetXor_ImportedGraph() public void NeuralNetXor_ImportedGraph()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
Assert.IsTrue(new NeuralNetXor() { Enabled = true, ImportGraph = true }.Run());
Assert.IsTrue(new NeuralNetXor() { Enabled = true, IsImportingGraph = true }.Train());
} }
@@ -118,7 +118,7 @@ namespace TensorFlowNET.ExamplesTests
public void ObjectDetection() public void ObjectDetection()
{ {
tf.Graph().as_default(); tf.Graph().as_default();
Assert.IsTrue(new ObjectDetection() { Enabled = true, ImportGraph = true }.Run());
Assert.IsTrue(new ObjectDetection() { Enabled = true, IsImportingGraph = true }.Train());
} }
} }
} }

Loading…
Cancel
Save