From 625368abec137d49b3ed0b24403c03d5308fdd0b Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 23 Mar 2019 14:26:12 -0500 Subject: [PATCH] fix result of session.run --- .../Sessions/_ElementFetchMapper.cs | 19 ++++++++++++-- .../Sessions/_FetchHandler.cs | 13 ++++++---- test/TensorFlowNET.Examples/BasicEagerApi.cs | 6 ++++- .../TensorFlowNET.Examples/BasicOperations.cs | 14 +++++------ test/TensorFlowNET.Examples/HelloWorld.cs | 14 +++++------ test/TensorFlowNET.Examples/IExample.cs | 3 ++- .../ImageRecognition.cs | 7 +++++- .../InceptionArchGoogLeNet.cs | 5 +++- .../LinearRegression.cs | 11 +++++--- .../LogisticRegression.cs | 7 +++--- test/TensorFlowNET.Examples/MetaGraph.cs | 5 +++- .../NaiveBayesClassifier.cs | 6 +++-- .../NamedEntityRecognition.cs | 3 ++- test/TensorFlowNET.Examples/Program.cs | 25 ++++++++++++++++--- .../TensorFlowNET.Examples.csproj | 1 + .../TextClassificationTrain.cs | 6 +++-- .../TextClassificationWithMovieReviews.cs | 5 +++- 17 files changed, 106 insertions(+), 44 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs index bd86e8d8..8cf84cdf 100644 --- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs +++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs @@ -43,8 +43,23 @@ namespace Tensorflow case NDArray value: result = value; break; - case float fVal: - result = fVal; + case short value: + result = value; + break; + case int value: + result = value; + break; + case long value: + result = value; + break; + case float value: + result = value; + break; + case double value: + result = value; + break; + case string value: + result = value; break; default: break; diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 626fc6e8..b101f4bf 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -58,11 +58,7 @@ namespace Tensorflow { var value = tensor_values[j]; j += 1; - if (value.ndim == 2) - { - full_values.Add(value[0]); - } - else + if (value.ndim == 0) { switch (value.dtype.Name) { @@ -75,8 +71,15 @@ namespace Tensorflow case "Double": full_values.Add(value.Data(0)); break; + case "String": + full_values.Add(value.Data(0)); + break; } } + else + { + full_values.Add(value[np.arange(1)]); + } } i += 1; } diff --git a/test/TensorFlowNET.Examples/BasicEagerApi.cs b/test/TensorFlowNET.Examples/BasicEagerApi.cs index 3f1b325c..3440d174 100644 --- a/test/TensorFlowNET.Examples/BasicEagerApi.cs +++ b/test/TensorFlowNET.Examples/BasicEagerApi.cs @@ -11,9 +11,11 @@ namespace TensorFlowNET.Examples /// public class BasicEagerApi : IExample { + public bool Enabled => false; + private Tensor a, b, c, d; - public void Run() + public bool Run() { // Set Eager API Console.WriteLine("Setting Eager mode..."); @@ -34,6 +36,8 @@ namespace TensorFlowNET.Examples Console.WriteLine($"a * b = {d}"); // Full compatibility with Numpy + + return true; } public void PrepareData() diff --git a/test/TensorFlowNET.Examples/BasicOperations.cs b/test/TensorFlowNET.Examples/BasicOperations.cs index 3e263348..829ecc9a 100644 --- a/test/TensorFlowNET.Examples/BasicOperations.cs +++ b/test/TensorFlowNET.Examples/BasicOperations.cs @@ -10,11 +10,12 @@ namespace TensorFlowNET.Examples /// Basic Operations example using TensorFlow library. /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/1_Introduction/basic_operations.py /// - public class BasicOperations : IExample + public class BasicOperations : Python, IExample { + public bool Enabled => true; private Session sess; - public void Run() + public bool Run() { // Basic constant operations // The value returned by the constructor represents the output @@ -86,15 +87,12 @@ namespace TensorFlowNET.Examples // graph: the two constants and matmul. // // The output of the op is returned in 'result' as a numpy `ndarray` object. - using (sess = tf.Session()) + return with(tf.Session(), sess => { var result = sess.run(product); Console.WriteLine(result.ToString()); // ==> [[ 12.]] - if (result.Data()[0] != 12) - { - throw new ValueError("BasicOperations"); - } - } + return result.Data()[0] == 12; + }); } public void PrepareData() diff --git a/test/TensorFlowNET.Examples/HelloWorld.cs b/test/TensorFlowNET.Examples/HelloWorld.cs index 7c94b13c..f9b2baa8 100644 --- a/test/TensorFlowNET.Examples/HelloWorld.cs +++ b/test/TensorFlowNET.Examples/HelloWorld.cs @@ -9,9 +9,10 @@ namespace TensorFlowNET.Examples /// Simple hello world using TensorFlow /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/1_Introduction/helloworld.py /// - public class HelloWorld : IExample + public class HelloWorld : Python, IExample { - public void Run() + public bool Enabled => true; + public bool Run() { /* Create a Constant op The op is added as a node to the default graph. @@ -22,16 +23,13 @@ namespace TensorFlowNET.Examples var hello = tf.constant(str); // Start tf session - using (var sess = tf.Session()) + return with(tf.Session(), sess => { // Run the op var result = sess.run(hello); Console.WriteLine(result.ToString()); - if(!result.ToString().Equals(str)) - { - throw new ValueError("HelloWorld example acts in unexpected way."); - } - } + return result.ToString().Equals(str); + }); } public void PrepareData() diff --git a/test/TensorFlowNET.Examples/IExample.cs b/test/TensorFlowNET.Examples/IExample.cs index d8761c36..b5c47b76 100644 --- a/test/TensorFlowNET.Examples/IExample.cs +++ b/test/TensorFlowNET.Examples/IExample.cs @@ -10,7 +10,8 @@ namespace TensorFlowNET.Examples /// public interface IExample { - void Run(); + bool Enabled { get; } + bool Run(); void PrepareData(); } } diff --git a/test/TensorFlowNET.Examples/ImageRecognition.cs b/test/TensorFlowNET.Examples/ImageRecognition.cs index 8c708c11..58e4bde3 100644 --- a/test/TensorFlowNET.Examples/ImageRecognition.cs +++ b/test/TensorFlowNET.Examples/ImageRecognition.cs @@ -12,12 +12,14 @@ namespace TensorFlowNET.Examples { public class ImageRecognition : Python, IExample { + public bool Enabled => true; + string dir = "ImageRecognition"; string pbFile = "tensorflow_inception_graph.pb"; string labelFile = "imagenet_comp_graph_label_strings.txt"; string picFile = "grace_hopper.jpg"; - public void Run() + public bool Run() { PrepareData(); @@ -54,7 +56,10 @@ namespace TensorFlowNET.Examples }); Console.WriteLine($"{picFile}: {labels[idx]} {propability}"); + return labels[idx].Equals("military uniform"); } + + return false; } private NDArray ReadTensorFromImageFile(string file_name, diff --git a/test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs b/test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs index 7b7b2e96..bfea1922 100644 --- a/test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs +++ b/test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs @@ -19,6 +19,7 @@ namespace TensorFlowNET.Examples /// public class InceptionArchGoogLeNet : Python, IExample { + public bool Enabled => false; string dir = "label_image_data"; string pbFile = "inception_v3_2016_08_28_frozen.pb"; string labelFile = "imagenet_slim_labels.txt"; @@ -30,7 +31,7 @@ namespace TensorFlowNET.Examples string input_name = "import/input"; string output_name = "import/InceptionV3/Predictions/Reshape_1"; - public void Run() + public bool Run() { PrepareData(); @@ -60,6 +61,8 @@ namespace TensorFlowNET.Examples foreach (float idx in top_k) Console.WriteLine($"{picFile}: {idx} {labels[(int)idx]}, {results[(int)idx]}"); + + return true; } private NDArray ReadTensorFromImageFile(string file_name, diff --git a/test/TensorFlowNET.Examples/LinearRegression.cs b/test/TensorFlowNET.Examples/LinearRegression.cs index 07e6fe78..efd9dc7e 100644 --- a/test/TensorFlowNET.Examples/LinearRegression.cs +++ b/test/TensorFlowNET.Examples/LinearRegression.cs @@ -12,6 +12,8 @@ namespace TensorFlowNET.Examples /// public class LinearRegression : Python, IExample { + public bool Enabled => true; + NumPyRandom rng = np.random; // Parameters @@ -22,7 +24,7 @@ namespace TensorFlowNET.Examples NDArray train_X, train_Y; int n_samples; - public void Run() + public bool Run() { // Training Data PrepareData(); @@ -52,7 +54,7 @@ namespace TensorFlowNET.Examples var init = tf.global_variables_initializer(); // Start training - with(tf.Session(), sess => + return with(tf.Session(), sess => { // Run the initializer sess.run(init); @@ -91,7 +93,10 @@ namespace TensorFlowNET.Examples new FeedItem(X, test_X), new FeedItem(Y, test_Y)); Console.WriteLine($"Testing cost={testing_cost}"); - Console.WriteLine($"Absolute mean square loss difference: {Math.Abs((float)training_cost - (float)testing_cost)}"); + var diff = Math.Abs((float)training_cost - (float)testing_cost); + Console.WriteLine($"Absolute mean square loss difference: {diff}"); + + return diff < 0.01; }); } diff --git a/test/TensorFlowNET.Examples/LogisticRegression.cs b/test/TensorFlowNET.Examples/LogisticRegression.cs index 7e2f8e3b..0d994fd3 100644 --- a/test/TensorFlowNET.Examples/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/LogisticRegression.cs @@ -17,6 +17,7 @@ namespace TensorFlowNET.Examples /// public class LogisticRegression : Python, IExample { + public bool Enabled => true; private float learning_rate = 0.01f; private int training_epochs = 10; private int batch_size = 100; @@ -24,7 +25,7 @@ namespace TensorFlowNET.Examples Datasets mnist; - public void Run() + public bool Run() { PrepareData(); @@ -48,7 +49,7 @@ namespace TensorFlowNET.Examples // Initialize the variables (i.e. assign their default value) var init = tf.global_variables_initializer(); - with(tf.Session(), sess => + return with(tf.Session(), sess => { // Run the initializer @@ -88,7 +89,7 @@ namespace TensorFlowNET.Examples float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels)); print($"Accuracy: {acc.ToString("F4")}"); - Predict(); + return acc > 0.9; }); } diff --git a/test/TensorFlowNET.Examples/MetaGraph.cs b/test/TensorFlowNET.Examples/MetaGraph.cs index 5b6f3648..52f14933 100644 --- a/test/TensorFlowNET.Examples/MetaGraph.cs +++ b/test/TensorFlowNET.Examples/MetaGraph.cs @@ -9,9 +9,12 @@ namespace TensorFlowNET.Examples { public class MetaGraph : Python, IExample { - public void Run() + public bool Enabled => false; + + public bool Run() { ImportMetaGraph("my-save-dir/"); + return false; } private void ImportMetaGraph(string dir) diff --git a/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs b/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs index c53f7c94..8fa3ec67 100644 --- a/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs +++ b/test/TensorFlowNET.Examples/NaiveBayesClassifier.cs @@ -11,15 +11,17 @@ namespace TensorFlowNET.Examples /// https://github.com/nicolov/naive_bayes_tensorflow /// public class NaiveBayesClassifier : Python, IExample - { + { + public bool Enabled => false; public Normal dist { get; set; } - public void Run() + public bool Run() { np.array(1.0f, 1.0f); var X = np.array(new float[][] { new float[] { 1.0f, 1.0f }, new float[] { 2.0f, 2.0f }, new float[] { -1.0f, -1.0f }, new float[] { -2.0f, -2.0f }, new float[] { 1.0f, -1.0f }, new float[] { 2.0f, -2.0f }, }); var y = np.array(0,0,1,1,2,2); fit(X, y); // Create a regular grid and classify each point + return false; } public void fit(NDArray X, NDArray y) diff --git a/test/TensorFlowNET.Examples/NamedEntityRecognition.cs b/test/TensorFlowNET.Examples/NamedEntityRecognition.cs index 7e056a06..9737354d 100644 --- a/test/TensorFlowNET.Examples/NamedEntityRecognition.cs +++ b/test/TensorFlowNET.Examples/NamedEntityRecognition.cs @@ -10,7 +10,8 @@ namespace TensorFlowNET.Examples /// public class NamedEntityRecognition : Python, IExample { - public void Run() + public bool Enabled => false; + public bool Run() { throw new NotImplementedException(); } diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs index 8ca5b6db..befab612 100644 --- a/test/TensorFlowNET.Examples/Program.cs +++ b/test/TensorFlowNET.Examples/Program.cs @@ -1,6 +1,9 @@ using System; +using System.Collections.Generic; +using System.Drawing; using System.Linq; using System.Reflection; +using Console = Colorful.Console; namespace TensorFlowNET.Examples { @@ -9,27 +12,41 @@ namespace TensorFlowNET.Examples static void Main(string[] args) { var assembly = Assembly.GetEntryAssembly(); - foreach(Type type in assembly.GetTypes().Where(x => x.GetInterfaces().Contains(typeof(IExample)))) + var errors = new List(); + var success = new List(); + var disabled = new List(); + + foreach (Type type in assembly.GetTypes().Where(x => x.GetInterfaces().Contains(typeof(IExample)))) { if (args.Length > 0 && !args.Contains(type.Name)) continue; - Console.WriteLine($"{DateTime.UtcNow} Starting {type.Name}"); + Console.WriteLine($"{DateTime.UtcNow} Starting {type.Name}", Color.Tan); var example = (IExample)Activator.CreateInstance(type); try { - example.Run(); + if (example.Enabled) + if (example.Run()) + success.Add(type.Name); + else + errors.Add(type.Name); + else + disabled.Add(type.Name); } catch (Exception ex) { Console.WriteLine(ex); } - Console.WriteLine($"{DateTime.UtcNow} Completed {type.Name}"); + Console.WriteLine($"{DateTime.UtcNow} Completed {type.Name}", Color.Tan); } + success.ForEach(x => Console.WriteLine($"{x} example is OK!", Color.Green)); + disabled.ForEach(x => Console.WriteLine($"{x} example is Disabled!", Color.Tan)); + errors.ForEach(x => Console.WriteLine($"{x} example is Failed!", Color.Red)); + Console.ReadLine(); } } diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 5545097a..9d3ba5a8 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -6,6 +6,7 @@ + diff --git a/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs index 5de3018c..b5d04627 100644 --- a/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs +++ b/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs @@ -15,13 +15,14 @@ namespace TensorFlowNET.Examples.CnnTextClassification /// public class TextClassificationTrain : Python, IExample { + public bool Enabled => false; private string dataDir = "text_classification"; private string dataFileName = "dbpedia_csv.tar.gz"; private const int CHAR_MAX_LEN = 1014; private const int NUM_CLASS = 2; - public void Run() + public bool Run() { PrepareData(); Console.WriteLine("Building dataset..."); @@ -29,9 +30,10 @@ namespace TensorFlowNET.Examples.CnnTextClassification var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); - with(tf.Session(), sess => + return with(tf.Session(), sess => { new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS); + return false; }); } diff --git a/test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs b/test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs index 3941bf12..e164ffbf 100644 --- a/test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs +++ b/test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs @@ -11,11 +11,12 @@ namespace TensorFlowNET.Examples { public class TextClassificationWithMovieReviews : Python, IExample { + public bool Enabled => false; string dir = "text_classification_with_movie_reviews"; string dataFile = "imdb.zip"; NDArray train_data, train_labels, test_data, test_labels; - public void Run() + public bool Run() { PrepareData(); @@ -39,6 +40,8 @@ namespace TensorFlowNET.Examples var model = keras.Sequential(); model.add(keras.layers.Embedding(vocab_size, 16)); + + return false; } public void PrepareData()