From 422bacfd80704895dcbaa8ac340eb307aca2a700 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 6 Jan 2019 00:32:58 -0600 Subject: [PATCH] fix OperationsTest.addInPlaceholder. --- src/TensorFlowNET.Core/APIs/tf.math.cs | 19 ++++++++++++++ .../Operations/gen_array_ops.cs | 10 +------ .../Sessions/BaseSession.cs | 6 +++-- src/TensorFlowNET.Core/Tensors/dtypes.cs | 15 ++--------- src/TensorFlowNET.Core/tf.cs | 6 +---- .../TensorFlowNET.Examples/BasicOperations.cs | 26 +++++++++++++++++-- test/TensorFlowNET.UnitTest/CSession.cs | 2 +- test/TensorFlowNET.UnitTest/OperationsTest.cs | 3 ++- 8 files changed, 54 insertions(+), 33 deletions(-) create mode 100644 src/TensorFlowNET.Core/APIs/tf.math.cs diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs new file mode 100644 index 00000000..08aa39bb --- /dev/null +++ b/src/TensorFlowNET.Core/APIs/tf.math.cs @@ -0,0 +1,19 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow +{ + public static partial class tf + { + public static unsafe Tensor add(Tensor a, Tensor b) + { + return gen_math_ops.add(a, b); + } + + public static unsafe Tensor multiply(Tensor x, Tensor y) + { + return gen_math_ops.mul(x, y); + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs index 8330fe28..5e727d9f 100644 --- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs @@ -12,13 +12,6 @@ namespace Tensorflow public static Tensor placeholder(TF_DataType dtype, TensorShape shape = null) { - /*var g = ops.get_default_graph(); - var op = new Operation(g, "Placeholder", "feed"); - - var tensor = new Tensor(op, 0, dtype); - - return tensor;*/ - var keywords = new Dictionary(); keywords.Add("dtype", dtype); keywords.Add("shape", shape); @@ -31,8 +24,7 @@ namespace Tensorflow _attrs["dtype"] = _op.get_attr("dtype"); _attrs["shape"] = _op.get_attr("shape"); - var tensor = new Tensor(_op, 0, dtype); - return tensor; + return new Tensor(_op, 0, dtype); } } } diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 0998bb7d..47fda701 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -105,8 +105,8 @@ namespace Tensorflow c_api.TF_SessionRun(_session, run_options: null, inputs: feed_dict.Select(f => f.Key).ToArray(), - input_values: new IntPtr[] { }, - ninputs: 0, + input_values: feed_dict.Select(f => (IntPtr)f.Value).ToArray(), + ninputs: feed_dict.Length, outputs: fetch_list, output_values: output_values, noutputs: fetch_list.Length, @@ -115,6 +115,8 @@ namespace Tensorflow run_metadata: IntPtr.Zero, status: status); + status.Check(true); + object[] result = new object[fetch_list.Length]; for (int i = 0; i < fetch_list.Length; i++) diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index d8f4bcba..ff5eb5eb 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -37,20 +37,9 @@ namespace Tensorflow switch (type) { - case TF_DataType.TF_INT32: - dtype = DataType.DtInt32; - break; - case TF_DataType.TF_FLOAT: - dtype = DataType.DtFloat; - break; - case TF_DataType.TF_DOUBLE: - dtype = DataType.DtDouble; - break; - case TF_DataType.TF_STRING: - dtype = DataType.DtString; - break; default: - throw new Exception("Not Implemented"); + Enum.TryParse(((int)type).ToString(), out dtype); + break; } return dtype; diff --git a/src/TensorFlowNET.Core/tf.cs b/src/TensorFlowNET.Core/tf.cs index 76bbd28e..fbdaf470 100644 --- a/src/TensorFlowNET.Core/tf.cs +++ b/src/TensorFlowNET.Core/tf.cs @@ -10,6 +10,7 @@ namespace Tensorflow { public static partial class tf { + public static TF_DataType int16 = TF_DataType.TF_INT16; public static TF_DataType float32 = TF_DataType.TF_FLOAT; public static TF_DataType chars = TF_DataType.TF_STRING; @@ -22,11 +23,6 @@ namespace Tensorflow return new RefVariable(data, dtype); } - public static unsafe Tensor add(Tensor a, Tensor b) - { - return gen_math_ops.add(a, b); - } - public static unsafe Tensor placeholder(TF_DataType dtype, TensorShape shape = null) { return gen_array_ops.placeholder(dtype, shape); diff --git a/test/TensorFlowNET.Examples/BasicOperations.cs b/test/TensorFlowNET.Examples/BasicOperations.cs index 9ec287f9..c43effd2 100644 --- a/test/TensorFlowNET.Examples/BasicOperations.cs +++ b/test/TensorFlowNET.Examples/BasicOperations.cs @@ -11,6 +11,8 @@ namespace TensorFlowNET.Examples /// public class BasicOperations : IExample { + private Session sess; + public void Run() { // Basic constant operations @@ -18,14 +20,34 @@ namespace TensorFlowNET.Examples // of the Constant op. var a = tf.constant(2); var b = tf.constant(3); - var c = a * b; + // Launch the default graph. - using (var sess = tf.Session()) + using (sess = tf.Session()) { Console.WriteLine("a=2, b=3"); Console.WriteLine($"Addition with constants: {sess.run(a + b)}"); Console.WriteLine($"Multiplication with constants: {sess.run(a * b)}"); } + + // Basic Operations with variable as graph input + // The value returned by the constructor represents the output + // of the Variable op. (define as input when running session) + // tf Graph input + a = tf.placeholder(tf.int16); + b = tf.placeholder(tf.int16); + + // Define some operations + var add = tf.add(a, b); + var mul = tf.multiply(a, b); + + // Launch the default graph. + using(sess = tf.Session()) + { + // var feed_dict = new Dictionary + // Run every operation with variable input + // Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict: {a: 2, b: 3})}"); + // Console.WriteLine($"Multiplication with variables: {}"); + } } } } diff --git a/test/TensorFlowNET.UnitTest/CSession.cs b/test/TensorFlowNET.UnitTest/CSession.cs index ef6d469e..70e549f3 100644 --- a/test/TensorFlowNET.UnitTest/CSession.cs +++ b/test/TensorFlowNET.UnitTest/CSession.cs @@ -78,7 +78,7 @@ namespace TensorFlowNET.UnitTest var output_values_ptr = output_values_.Select(x => (IntPtr)x).ToArray(); IntPtr targets_ptr = IntPtr.Zero; - c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, 1, + c_api.TF_SessionRun(session_, null, inputs_ptr, input_values_ptr, inputs_ptr.Length, outputs_ptr, output_values_ptr, outputs_.Count, targets_ptr, targets_.Count, IntPtr.Zero, s); diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 2a71cbec..4a7fc054 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -35,7 +35,8 @@ namespace TensorFlowNET.UnitTest feed_dict.Add(a, 3.0f); feed_dict.Add(b, 2.0f); - //var o = sess.run(c, feed_dict); + var o = sess.run(c, feed_dict); + Assert.AreEqual(o, 5.0f); } }