Browse Source

Xor: import graph

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
438a0f9720
18 changed files with 87 additions and 3 deletions
  1. BIN
      graph/xor.meta
  2. +1
    -0
      test/TensorFlowNET.Examples/BasicEagerApi.cs
  3. +2
    -0
      test/TensorFlowNET.Examples/BasicOperations.cs
  4. +2
    -0
      test/TensorFlowNET.Examples/HelloWorld.cs
  5. +6
    -0
      test/TensorFlowNET.Examples/IExample.cs
  6. +2
    -0
      test/TensorFlowNET.Examples/ImageRecognition.cs
  7. +2
    -0
      test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs
  8. +1
    -0
      test/TensorFlowNET.Examples/KMeansClustering.cs
  9. +2
    -0
      test/TensorFlowNET.Examples/LinearRegression.cs
  10. +2
    -0
      test/TensorFlowNET.Examples/LogisticRegression.cs
  11. +2
    -0
      test/TensorFlowNET.Examples/MetaGraph.cs
  12. +2
    -0
      test/TensorFlowNET.Examples/NaiveBayesClassifier.cs
  13. +2
    -0
      test/TensorFlowNET.Examples/NamedEntityRecognition.cs
  14. +2
    -0
      test/TensorFlowNET.Examples/NearestNeighbor.cs
  15. +55
    -3
      test/TensorFlowNET.Examples/NeuralNetXor.cs
  16. +2
    -0
      test/TensorFlowNET.Examples/ObjectDetection.cs
  17. +1
    -0
      test/TensorFlowNET.Examples/Text/TextClassificationTrain.cs
  18. +1
    -0
      test/TensorFlowNET.Examples/Text/TextClassificationWithMovieReviews.cs

BIN
graph/xor.meta View File


+ 1
- 0
test/TensorFlowNET.Examples/BasicEagerApi.cs View File

@@ -14,6 +14,7 @@ namespace TensorFlowNET.Examples
public int Priority => 100; public int Priority => 100;
public bool Enabled { get; set; } = false; public bool Enabled { get; set; } = false;
public string Name => "Basic Eager"; public string Name => "Basic Eager";
public bool ImportGraph { get; set; } = false;


private Tensor a, b, c, d; private Tensor a, b, c, d;




+ 2
- 0
test/TensorFlowNET.Examples/BasicOperations.cs View File

@@ -15,6 +15,8 @@ namespace TensorFlowNET.Examples
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public int Priority => 2; public int Priority => 2;
public string Name => "Basic Operations"; public string Name => "Basic Operations";
public bool ImportGraph { get; set; } = false;



private Session sess; private Session sess;




+ 2
- 0
test/TensorFlowNET.Examples/HelloWorld.cs View File

@@ -14,6 +14,8 @@ namespace TensorFlowNET.Examples
public int Priority => 1; public int Priority => 1;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Hello World"; public string Name => "Hello World";
public bool ImportGraph { get; set; } = false;



public bool Run() public bool Run()
{ {


+ 6
- 0
test/TensorFlowNET.Examples/IExample.cs View File

@@ -14,11 +14,17 @@ namespace TensorFlowNET.Examples
/// running order /// running order
/// </summary> /// </summary>
int Priority { get; } int Priority { get; }

/// <summary> /// <summary>
/// True to run example /// True to run example
/// </summary> /// </summary>
bool Enabled { get; set; } bool Enabled { get; set; }


/// <summary>
/// Set true to import the computation graph instead of building it.
/// </summary>
bool ImportGraph { get; set; }

string Name { get; } string Name { get; }


/// <summary> /// <summary>


+ 2
- 0
test/TensorFlowNET.Examples/ImageRecognition.cs View File

@@ -15,6 +15,8 @@ namespace TensorFlowNET.Examples
public int Priority => 7; public int Priority => 7;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Image Recognition"; public string Name => "Image Recognition";
public bool ImportGraph { get; set; } = false;



string dir = "ImageRecognition"; string dir = "ImageRecognition";
string pbFile = "tensorflow_inception_graph.pb"; string pbFile = "tensorflow_inception_graph.pb";


+ 2
- 0
test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs View File

@@ -22,6 +22,8 @@ namespace TensorFlowNET.Examples
public bool Enabled { get; set; } = false; public bool Enabled { get; set; } = false;
public int Priority => 100; public int Priority => 100;
public string Name => "Inception Arch GoogLeNet"; public string Name => "Inception Arch GoogLeNet";
public bool ImportGraph { get; set; } = false;



string dir = "label_image_data"; string dir = "label_image_data";
string pbFile = "inception_v3_2016_08_28_frozen.pb"; string pbFile = "inception_v3_2016_08_28_frozen.pb";


+ 1
- 0
test/TensorFlowNET.Examples/KMeansClustering.cs View File

@@ -20,6 +20,7 @@ namespace TensorFlowNET.Examples
public int Priority => 8; public int Priority => 8;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "K-means Clustering"; public string Name => "K-means Clustering";
public bool ImportGraph { get; set; } = true;


public int? train_size = null; public int? train_size = null;
public int validation_size = 5000; public int validation_size = 5000;


+ 2
- 0
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -15,6 +15,8 @@ namespace TensorFlowNET.Examples
public int Priority => 3; public int Priority => 3;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Linear Regression"; public string Name => "Linear Regression";
public bool ImportGraph { get; set; } = false;



public int training_epochs = 1000; public int training_epochs = 1000;




+ 2
- 0
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -20,6 +20,8 @@ namespace TensorFlowNET.Examples
public int Priority => 4; public int Priority => 4;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Logistic Regression"; public string Name => "Logistic Regression";
public bool ImportGraph { get; set; } = false;



public int training_epochs = 10; public int training_epochs = 10;
public int? train_size = null; public int? train_size = null;


+ 2
- 0
test/TensorFlowNET.Examples/MetaGraph.cs View File

@@ -12,6 +12,8 @@ namespace TensorFlowNET.Examples
public int Priority => 100; public int Priority => 100;
public bool Enabled { get; set; } = false; public bool Enabled { get; set; } = false;
public string Name => "Meta Graph"; public string Name => "Meta Graph";
public bool ImportGraph { get; set; } = true;



public bool Run() public bool Run()
{ {


+ 2
- 0
test/TensorFlowNET.Examples/NaiveBayesClassifier.cs View File

@@ -15,6 +15,8 @@ namespace TensorFlowNET.Examples
public int Priority => 6; public int Priority => 6;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Naive Bayes Classifier"; public string Name => "Naive Bayes Classifier";
public bool ImportGraph { get; set; } = false;



public Normal dist { get; set; } public Normal dist { get; set; }
public bool Run() public bool Run()


+ 2
- 0
test/TensorFlowNET.Examples/NamedEntityRecognition.cs View File

@@ -13,6 +13,8 @@ namespace TensorFlowNET.Examples
public int Priority => 100; public int Priority => 100;
public bool Enabled { get; set; } = false; public bool Enabled { get; set; } = false;
public string Name => "NER"; public string Name => "NER";
public bool ImportGraph { get; set; } = false;



public bool Run() public bool Run()
{ {


+ 2
- 0
test/TensorFlowNET.Examples/NearestNeighbor.cs View File

@@ -22,6 +22,8 @@ namespace TensorFlowNET.Examples
public int? TrainSize = null; public int? TrainSize = null;
public int ValidationSize = 5000; public int ValidationSize = 5000;
public int? TestSize = null; public int? TestSize = null;
public bool ImportGraph { get; set; } = false;



public bool Run() public bool Run()
{ {


+ 55
- 3
test/TensorFlowNET.Examples/NeuralNetXor.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.Text; using System.Text;
using NumSharp; using NumSharp;
using Tensorflow; using Tensorflow;
using TensorFlowNET.Examples.Utility;
namespace TensorFlowNET.Examples namespace TensorFlowNET.Examples
{ {
@@ -15,6 +16,7 @@ namespace TensorFlowNET.Examples
public int Priority => 10; public int Priority => 10;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "NN XOR"; public string Name => "NN XOR";
public bool ImportGraph { get; set; } = true;
public int num_steps = 5000; public int num_steps = 5000;
@@ -38,7 +40,7 @@ namespace TensorFlowNET.Examples
// Shape [4] // Shape [4]
var predictions = tf.sigmoid(tf.squeeze(logits)); var predictions = tf.sigmoid(tf.squeeze(logits));
var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32)));
var loss = tf.reduce_mean(tf.square(predictions - tf.cast(labels, tf.float32)), name:"loss");
var gs = tf.Variable(0, trainable: false); var gs = tf.Variable(0, trainable: false);
var train_op = tf.train.GradientDescentOptimizer(0.2f).minimize(loss, global_step: gs); var train_op = tf.train.GradientDescentOptimizer(0.2f).minimize(loss, global_step: gs);
@@ -49,7 +51,53 @@ namespace TensorFlowNET.Examples
public bool Run() public bool Run()
{ {
PrepareData(); PrepareData();
float loss_value = 0;
if (ImportGraph)
loss_value = RunWithImportedGraph();
else
loss_value=RunWithBuiltGraph();
return loss_value < 0.0627;
}
private float RunWithImportedGraph()
{
var graph = tf.Graph().as_default();
tf.train.import_meta_graph("graph/xor.meta");
var features = graph.get_operation_by_name("Placeholder");
var labels = graph.get_operation_by_name("Placeholder_1");
Tensor loss = graph.get_operation_by_name("loss");
var init = tf.global_variables_initializer();
float loss_value = 0;
// Start tf session
with<Session>(tf.Session(graph), sess =>
{
sess.run(init);
var step = 0;
var y_ = np.array(new int[] { 1, 0, 0, 1 }, dtype: np.int32);
while (step < num_steps)
{
// original python:
//_, step, loss_value = sess.run(
// [train_op, gs, loss],
// feed_dict={features: xy, labels: y_}
// )
loss_value = sess.run(loss, new FeedItem(features, data), new FeedItem(labels, y_));
step++;
if (step % 1000 == 0)
Console.WriteLine($"Step {step} loss: {loss_value}");
}
Console.WriteLine($"Final loss: {loss_value}");
});
return loss_value;
}
private float RunWithBuiltGraph()
{
var graph = tf.Graph().as_default(); var graph = tf.Graph().as_default();
var features = tf.placeholder(tf.float32, new TensorShape(4, 2)); var features = tf.placeholder(tf.float32, new TensorShape(4, 2));
@@ -76,12 +124,12 @@ namespace TensorFlowNET.Examples
// ) // )
loss_value = sess.run(loss, new FeedItem(features, data), new FeedItem(labels, y_)); loss_value = sess.run(loss, new FeedItem(features, data), new FeedItem(labels, y_));
step++; step++;
if (step%1000==0)
if (step % 1000 == 0)
Console.WriteLine($"Step {step} loss: {loss_value}"); Console.WriteLine($"Step {step} loss: {loss_value}");
} }
Console.WriteLine($"Final loss: {loss_value}"); Console.WriteLine($"Final loss: {loss_value}");
}); });
return loss_value < 0.0627;
return loss_value;
} }
public void PrepareData() public void PrepareData()
@@ -93,6 +141,10 @@ namespace TensorFlowNET.Examples
{0, 0 }, {0, 0 },
{0, 1 } {0, 1 }
}; };
// download graph meta data
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/xor.meta";
Web.Download(url, "graph", "kmeans.meta");
} }
} }
} }

+ 2
- 0
test/TensorFlowNET.Examples/ObjectDetection.cs View File

@@ -18,6 +18,8 @@ namespace TensorFlowNET.Examples
public int Priority => 11; public int Priority => 11;
public bool Enabled { get; set; } = true; public bool Enabled { get; set; } = true;
public string Name => "Object Detection"; public string Name => "Object Detection";
public bool ImportGraph { get; set; } = false;

public float MIN_SCORE = 0.5f; public float MIN_SCORE = 0.5f;


string modelDir = "ssd_mobilenet_v1_coco_2018_01_28"; string modelDir = "ssd_mobilenet_v1_coco_2018_01_28";


+ 1
- 0
test/TensorFlowNET.Examples/Text/TextClassificationTrain.cs View File

@@ -18,6 +18,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
public bool Enabled { get; set; }= false; public bool Enabled { get; set; }= false;
public string Name => "Text Classification"; public string Name => "Text Classification";
public int? DataLimit = null; public int? DataLimit = null;
public bool ImportGraph { get; set; } = true;


private string dataDir = "text_classification"; private string dataDir = "text_classification";
private string dataFileName = "dbpedia_csv.tar.gz"; private string dataFileName = "dbpedia_csv.tar.gz";


+ 1
- 0
test/TensorFlowNET.Examples/Text/TextClassificationWithMovieReviews.cs View File

@@ -14,6 +14,7 @@ namespace TensorFlowNET.Examples
public int Priority => 9; public int Priority => 9;
public bool Enabled { get; set; } = false; public bool Enabled { get; set; } = false;
public string Name => "Movie Reviews"; public string Name => "Movie Reviews";
public bool ImportGraph { get; set; } = true;


string dir = "text_classification_with_movie_reviews"; string dir = "text_classification_with_movie_reviews";
string dataFile = "imdb.zip"; string dataFile = "imdb.zip";


Loading…
Cancel
Save