Browse Source

fix result of session.run

tags/v0.9
Oceania2018 6 years ago
parent
commit
625368abec
17 changed files with 106 additions and 44 deletions
  1. +17
    -2
      src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
  2. +8
    -5
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  3. +5
    -1
      test/TensorFlowNET.Examples/BasicEagerApi.cs
  4. +6
    -8
      test/TensorFlowNET.Examples/BasicOperations.cs
  5. +6
    -8
      test/TensorFlowNET.Examples/HelloWorld.cs
  6. +2
    -1
      test/TensorFlowNET.Examples/IExample.cs
  7. +6
    -1
      test/TensorFlowNET.Examples/ImageRecognition.cs
  8. +4
    -1
      test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs
  9. +8
    -3
      test/TensorFlowNET.Examples/LinearRegression.cs
  10. +4
    -3
      test/TensorFlowNET.Examples/LogisticRegression.cs
  11. +4
    -1
      test/TensorFlowNET.Examples/MetaGraph.cs
  12. +4
    -2
      test/TensorFlowNET.Examples/NaiveBayesClassifier.cs
  13. +2
    -1
      test/TensorFlowNET.Examples/NamedEntityRecognition.cs
  14. +21
    -4
      test/TensorFlowNET.Examples/Program.cs
  15. +1
    -0
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  16. +4
    -2
      test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs
  17. +4
    -1
      test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs

+ 17
- 2
src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs View File

@@ -43,8 +43,23 @@ namespace Tensorflow
case NDArray value:
result = value;
break;
case float fVal:
result = fVal;
case short value:
result = value;
break;
case int value:
result = value;
break;
case long value:
result = value;
break;
case float value:
result = value;
break;
case double value:
result = value;
break;
case string value:
result = value;
break;
default:
break;


+ 8
- 5
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -58,11 +58,7 @@ namespace Tensorflow
{
var value = tensor_values[j];
j += 1;
if (value.ndim == 2)
{
full_values.Add(value[0]);
}
else
if (value.ndim == 0)
{
switch (value.dtype.Name)
{
@@ -75,8 +71,15 @@ namespace Tensorflow
case "Double":
full_values.Add(value.Data<double>(0));
break;
case "String":
full_values.Add(value.Data<string>(0));
break;
}
}
else
{
full_values.Add(value[np.arange(1)]);
}
}
i += 1;
}


+ 5
- 1
test/TensorFlowNET.Examples/BasicEagerApi.cs View File

@@ -11,9 +11,11 @@ namespace TensorFlowNET.Examples
/// </summary>
public class BasicEagerApi : IExample
{
public bool Enabled => false;

private Tensor a, b, c, d;

public void Run()
public bool Run()
{
// Set Eager API
Console.WriteLine("Setting Eager mode...");
@@ -34,6 +36,8 @@ namespace TensorFlowNET.Examples
Console.WriteLine($"a * b = {d}");

// Full compatibility with Numpy

return true;
}

public void PrepareData()


+ 6
- 8
test/TensorFlowNET.Examples/BasicOperations.cs View File

@@ -10,11 +10,12 @@ namespace TensorFlowNET.Examples
/// Basic Operations example using TensorFlow library.
/// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/1_Introduction/basic_operations.py
/// </summary>
public class BasicOperations : IExample
public class BasicOperations : Python, IExample
{
public bool Enabled => true;
private Session sess;

public void Run()
public bool Run()
{
// Basic constant operations
// The value returned by the constructor represents the output
@@ -86,15 +87,12 @@ namespace TensorFlowNET.Examples
// graph: the two constants and matmul.
//
// The output of the op is returned in 'result' as a numpy `ndarray` object.
using (sess = tf.Session())
return with(tf.Session(), sess =>
{
var result = sess.run(product);
Console.WriteLine(result.ToString()); // ==> [[ 12.]]
if (result.Data<int>()[0] != 12)
{
throw new ValueError("BasicOperations");
}
}
return result.Data<int>()[0] == 12;
});
}

public void PrepareData()


+ 6
- 8
test/TensorFlowNET.Examples/HelloWorld.cs View File

@@ -9,9 +9,10 @@ namespace TensorFlowNET.Examples
/// Simple hello world using TensorFlow
/// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/1_Introduction/helloworld.py
/// </summary>
public class HelloWorld : IExample
public class HelloWorld : Python, IExample
{
public void Run()
public bool Enabled => true;
public bool Run()
{
/* Create a Constant op
The op is added as a node to the default graph.
@@ -22,16 +23,13 @@ namespace TensorFlowNET.Examples
var hello = tf.constant(str);

// Start tf session
using (var sess = tf.Session())
return with(tf.Session(), sess =>
{
// Run the op
var result = sess.run(hello);
Console.WriteLine(result.ToString());
if(!result.ToString().Equals(str))
{
throw new ValueError("HelloWorld example acts in unexpected way.");
}
}
return result.ToString().Equals(str);
});
}

public void PrepareData()


+ 2
- 1
test/TensorFlowNET.Examples/IExample.cs View File

@@ -10,7 +10,8 @@ namespace TensorFlowNET.Examples
/// </summary>
public interface IExample
{
void Run();
bool Enabled { get; }
bool Run();
void PrepareData();
}
}

+ 6
- 1
test/TensorFlowNET.Examples/ImageRecognition.cs View File

@@ -12,12 +12,14 @@ namespace TensorFlowNET.Examples
{
public class ImageRecognition : Python, IExample
{
public bool Enabled => true;

string dir = "ImageRecognition";
string pbFile = "tensorflow_inception_graph.pb";
string labelFile = "imagenet_comp_graph_label_strings.txt";
string picFile = "grace_hopper.jpg";

public void Run()
public bool Run()
{
PrepareData();

@@ -54,7 +56,10 @@ namespace TensorFlowNET.Examples
});

Console.WriteLine($"{picFile}: {labels[idx]} {propability}");
return labels[idx].Equals("military uniform");
}

return false;
}

private NDArray ReadTensorFromImageFile(string file_name,


+ 4
- 1
test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs View File

@@ -19,6 +19,7 @@ namespace TensorFlowNET.Examples
/// </summary>
public class InceptionArchGoogLeNet : Python, IExample
{
public bool Enabled => false;
string dir = "label_image_data";
string pbFile = "inception_v3_2016_08_28_frozen.pb";
string labelFile = "imagenet_slim_labels.txt";
@@ -30,7 +31,7 @@ namespace TensorFlowNET.Examples
string input_name = "import/input";
string output_name = "import/InceptionV3/Predictions/Reshape_1";

public void Run()
public bool Run()
{
PrepareData();

@@ -60,6 +61,8 @@ namespace TensorFlowNET.Examples

foreach (float idx in top_k)
Console.WriteLine($"{picFile}: {idx} {labels[(int)idx]}, {results[(int)idx]}");

return true;
}

private NDArray ReadTensorFromImageFile(string file_name,


+ 8
- 3
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -12,6 +12,8 @@ namespace TensorFlowNET.Examples
/// </summary>
public class LinearRegression : Python, IExample
{
public bool Enabled => true;

NumPyRandom rng = np.random;

// Parameters
@@ -22,7 +24,7 @@ namespace TensorFlowNET.Examples
NDArray train_X, train_Y;
int n_samples;

public void Run()
public bool Run()
{
// Training Data
PrepareData();
@@ -52,7 +54,7 @@ namespace TensorFlowNET.Examples
var init = tf.global_variables_initializer();

// Start training
with(tf.Session(), sess =>
return with(tf.Session(), sess =>
{
// Run the initializer
sess.run(init);
@@ -91,7 +93,10 @@ namespace TensorFlowNET.Examples
new FeedItem(X, test_X),
new FeedItem(Y, test_Y));
Console.WriteLine($"Testing cost={testing_cost}");
Console.WriteLine($"Absolute mean square loss difference: {Math.Abs((float)training_cost - (float)testing_cost)}");
var diff = Math.Abs((float)training_cost - (float)testing_cost);
Console.WriteLine($"Absolute mean square loss difference: {diff}");

return diff < 0.01;
});
}



+ 4
- 3
test/TensorFlowNET.Examples/LogisticRegression.cs View File

@@ -17,6 +17,7 @@ namespace TensorFlowNET.Examples
/// </summary>
public class LogisticRegression : Python, IExample
{
public bool Enabled => true;
private float learning_rate = 0.01f;
private int training_epochs = 10;
private int batch_size = 100;
@@ -24,7 +25,7 @@ namespace TensorFlowNET.Examples

Datasets mnist;

public void Run()
public bool Run()
{
PrepareData();

@@ -48,7 +49,7 @@ namespace TensorFlowNET.Examples
// Initialize the variables (i.e. assign their default value)
var init = tf.global_variables_initializer();

with(tf.Session(), sess =>
return with(tf.Session(), sess =>
{
// Run the initializer
@@ -88,7 +89,7 @@ namespace TensorFlowNET.Examples
float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels));
print($"Accuracy: {acc.ToString("F4")}");

Predict();
return acc > 0.9;
});
}



+ 4
- 1
test/TensorFlowNET.Examples/MetaGraph.cs View File

@@ -9,9 +9,12 @@ namespace TensorFlowNET.Examples
{
public class MetaGraph : Python, IExample
{
public void Run()
public bool Enabled => false;

public bool Run()
{
ImportMetaGraph("my-save-dir/");
return false;
}

private void ImportMetaGraph(string dir)


+ 4
- 2
test/TensorFlowNET.Examples/NaiveBayesClassifier.cs View File

@@ -11,15 +11,17 @@ namespace TensorFlowNET.Examples
/// https://github.com/nicolov/naive_bayes_tensorflow
/// </summary>
public class NaiveBayesClassifier : Python, IExample
{
{
public bool Enabled => false;
public Normal dist { get; set; }
public void Run()
public bool Run()
{
np.array(1.0f, 1.0f);
var X = np.array(new float[][] { new float[] { 1.0f, 1.0f }, new float[] { 2.0f, 2.0f }, new float[] { -1.0f, -1.0f }, new float[] { -2.0f, -2.0f }, new float[] { 1.0f, -1.0f }, new float[] { 2.0f, -2.0f }, });
var y = np.array(0,0,1,1,2,2);
fit(X, y);
// Create a regular grid and classify each point
return false;
}

public void fit(NDArray X, NDArray y)


+ 2
- 1
test/TensorFlowNET.Examples/NamedEntityRecognition.cs View File

@@ -10,7 +10,8 @@ namespace TensorFlowNET.Examples
/// </summary>
public class NamedEntityRecognition : Python, IExample
{
public void Run()
public bool Enabled => false;
public bool Run()
{
throw new NotImplementedException();
}


+ 21
- 4
test/TensorFlowNET.Examples/Program.cs View File

@@ -1,6 +1,9 @@
using System;
using System.Collections.Generic;
using System.Drawing;
using System.Linq;
using System.Reflection;
using Console = Colorful.Console;

namespace TensorFlowNET.Examples
{
@@ -9,27 +12,41 @@ namespace TensorFlowNET.Examples
static void Main(string[] args)
{
var assembly = Assembly.GetEntryAssembly();
foreach(Type type in assembly.GetTypes().Where(x => x.GetInterfaces().Contains(typeof(IExample))))
var errors = new List<string>();
var success = new List<string>();
var disabled = new List<string>();

foreach (Type type in assembly.GetTypes().Where(x => x.GetInterfaces().Contains(typeof(IExample))))
{
if (args.Length > 0 && !args.Contains(type.Name))
continue;

Console.WriteLine($"{DateTime.UtcNow} Starting {type.Name}");
Console.WriteLine($"{DateTime.UtcNow} Starting {type.Name}", Color.Tan);

var example = (IExample)Activator.CreateInstance(type);

try
{
example.Run();
if (example.Enabled)
if (example.Run())
success.Add(type.Name);
else
errors.Add(type.Name);
else
disabled.Add(type.Name);
}
catch (Exception ex)
{
Console.WriteLine(ex);
}

Console.WriteLine($"{DateTime.UtcNow} Completed {type.Name}");
Console.WriteLine($"{DateTime.UtcNow} Completed {type.Name}", Color.Tan);
}

success.ForEach(x => Console.WriteLine($"{x} example is OK!", Color.Green));
disabled.ForEach(x => Console.WriteLine($"{x} example is Disabled!", Color.Tan));
errors.ForEach(x => Console.WriteLine($"{x} example is Failed!", Color.Red));

Console.ReadLine();
}
}


+ 1
- 0
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -6,6 +6,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Colorful.Console" Version="1.2.9" />
<PackageReference Include="Newtonsoft.Json" Version="12.0.1" />
<PackageReference Include="NumSharp" Version="0.8.1" />
<PackageReference Include="SharpZipLib" Version="1.1.0" />


+ 4
- 2
test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs View File

@@ -15,13 +15,14 @@ namespace TensorFlowNET.Examples.CnnTextClassification
/// </summary>
public class TextClassificationTrain : Python, IExample
{
public bool Enabled => false;
private string dataDir = "text_classification";
private string dataFileName = "dbpedia_csv.tar.gz";

private const int CHAR_MAX_LEN = 1014;
private const int NUM_CLASS = 2;

public void Run()
public bool Run()
{
PrepareData();
Console.WriteLine("Building dataset...");
@@ -29,9 +30,10 @@ namespace TensorFlowNET.Examples.CnnTextClassification

var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);

with(tf.Session(), sess =>
return with(tf.Session(), sess =>
{
new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
return false;
});
}



+ 4
- 1
test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs View File

@@ -11,11 +11,12 @@ namespace TensorFlowNET.Examples
{
public class TextClassificationWithMovieReviews : Python, IExample
{
public bool Enabled => false;
string dir = "text_classification_with_movie_reviews";
string dataFile = "imdb.zip";
NDArray train_data, train_labels, test_data, test_labels;

public void Run()
public bool Run()
{
PrepareData();

@@ -39,6 +40,8 @@ namespace TensorFlowNET.Examples

var model = keras.Sequential();
model.add(keras.layers.Embedding(vocab_size, 16));

return false;
}

public void PrepareData()


Loading…
Cancel
Save