From 740ca28965624450a11e20c09996c8b0683582f8 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 29 Dec 2018 23:38:20 -0600 Subject: [PATCH] GetNodeDef --- src/TensorFlowNET.Core/Operations/Operation.cs | 1 + src/TensorFlowNET.Core/Operations/c_api.ops.cs | 3 +++ src/TensorFlowNET.Core/Operations/ops.cs | 9 ++++++--- src/TensorFlowNET.Core/Status/Status.cs | 3 ++- src/TensorFlowNET.Core/Tensors/Tensor.cs | 2 +- src/TensorFlowNET.Core/c_api_util.cs | 7 +------ test/TensorFlowNET.Examples/Program.cs | 11 ++++++++++- .../TensorFlowNET.Examples.csproj | 1 + test/TensorFlowNET.UnitTest/GraphTest.cs | 13 ++++++++----- test/TensorFlowNET.UnitTest/OperationsTest.cs | 2 +- test/TensorFlowNET.UnitTest/TensorTest.cs | 2 +- test/TensorFlowNET.UnitTest/c_test_util.cs | 13 +++++++++---- 12 files changed, 44 insertions(+), 23 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 1ab17cf8..3ab9d6f9 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -63,6 +63,7 @@ namespace Tensorflow _id_value = Graph._next_id(); _handle = ops._create_c_op(g, node_def, inputs); + NumOutputs = c_api.TF_OperationNumOutputs(_handle); _outputs = new Tensor[NumOutputs]; for (int i = 0; i < NumOutputs; i++) diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index bff7fd0c..cfde53c1 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -90,6 +90,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); + [DllImport(TensorFlowLibName)] + public static extern void TF_OperationToNodeDef(IntPtr oper, IntPtr buffer, IntPtr status); + [DllImport(TensorFlowLibName)] public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status); diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/Operations/ops.cs index 28449fc2..f0eefdda 100644 --- a/src/TensorFlowNET.Core/Operations/ops.cs +++ b/src/TensorFlowNET.Core/Operations/ops.cs @@ -6,6 +6,7 @@ using System.Threading; using Tensorflow; using node_def_pb2 = Tensorflow; using Google.Protobuf; +using System.Linq; namespace Tensorflow { @@ -27,12 +28,14 @@ namespace Tensorflow var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name); // Add inputs - if(inputs != null) + if(inputs != null && inputs.Count > 0) { - foreach (var op_input in inputs) + /*foreach (var op_input in inputs) { c_api.TF_AddInput(op_desc, op_input._as_tf_output()); - } + }*/ + + c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); } var status = new Status(); diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index 84a15aec..e0d3edca 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -40,7 +40,8 @@ namespace Tensorflow { if(Code != TF_Code.TF_OK) { - throw new Exception(Message); + Console.WriteLine(Message); + // throw new Exception(Message); } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 4c4f19ff..c3a634da 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -118,7 +118,7 @@ namespace Tensorflow public TF_Output _as_tf_output() { - return c_api_util.tf_output(op, value_index); + return new TF_Output(op, value_index); } public T[] Data() diff --git a/src/TensorFlowNET.Core/c_api_util.cs b/src/TensorFlowNET.Core/c_api_util.cs index f4a918aa..4ee32805 100644 --- a/src/TensorFlowNET.Core/c_api_util.cs +++ b/src/TensorFlowNET.Core/c_api_util.cs @@ -8,12 +8,7 @@ namespace Tensorflow { public static TF_Output tf_output(IntPtr c_op, int index) { - - var ret = new TF_Output(); - ret.oper = c_op; - ret.index = index; - - return ret; + return new TF_Output(c_op, index); } } } diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs index fdb0c2bb..70adeb7d 100644 --- a/test/TensorFlowNET.Examples/Program.cs +++ b/test/TensorFlowNET.Examples/Program.cs @@ -12,7 +12,16 @@ namespace TensorFlowNET.Examples foreach(Type type in assembly.GetTypes().Where(x => x.GetInterfaces().Contains(typeof(IExample)))) { var example = (IExample)Activator.CreateInstance(type); - example.Run(); + + try + { + example.Run(); + } + catch (Exception ex) + { + Console.WriteLine(ex); + Console.ReadLine(); + } } } } diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index bf59b53f..56518593 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -10,6 +10,7 @@ + diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 708e44ae..9b95820d 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -33,20 +33,23 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual(0, feed.NumControlOutputs); AttrValue attr_value = null; - c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s); + Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s)); Assert.AreEqual(attr_value.Type, DataType.DtInt32); // Test not found errors in TF_Operation*() query functions. - // Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); - // Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); - // Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); - // Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); + Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); + Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); + //Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); + //Assert.AreEqual("Operation '' has no attr named 'missing'.", s.Message); // Make a constant oper with the scalar "3". var three = c_test_util.ScalarConst(3, graph, s); // Add oper. var add = c_test_util.Add(feed, three, graph, s); + + NodeDef node_def = null; + c_test_util.GetNodeDef(feed, ref node_def); } } } diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index c45a146e..24c0e701 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -43,7 +43,7 @@ namespace TensorFlowNET.UnitTest public void addInConstant() { var a = tf.constant(4.0f); - var b = tf.constant(5.0f); + var b = tf.placeholder(tf.float32); var c = tf.add(a, b); using (var sess = tf.Session()) diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index 4737f953..201e1411 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -139,7 +139,7 @@ namespace TensorFlowNET.UnitTest c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); //Assert.IsTrue(s.Code == TF_Code.TF_OK); - graph.Dispose(); + // graph.Dispose(); s.Dispose(); } } diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index fa845147..489226b1 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -42,6 +42,15 @@ namespace TensorFlowNET.UnitTest return s.Code == TF_Code.TF_OK; } + public static bool GetNodeDef(Operation oper, ref NodeDef node_def) + { + var s = new Status(); + var buffer = new Buffer(); + c_api.TF_OperationToNodeDef(oper, buffer, s); + + return s.Code == TF_Code.TF_OK; + } + public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) { var desc = c_api.TF_NewOperation(graph, "Placeholder", name); @@ -69,10 +78,6 @@ namespace TensorFlowNET.UnitTest c_api.TF_SetAttrType(desc, "dtype", t.dtype); op = c_api.TF_FinishOperation(desc, s); s.Check(); - if(op == null) - { - throw new Exception("c_api.TF_FinishOperation failed."); - } } public static Operation Const(Tensor t, Graph graph, Status s, string name)