diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index ad4a4046..afc8ce1f 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -288,6 +288,6 @@ namespace Tensorflow /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TF_UpdateEdge(IntPtr graph, TF_Output new_src, TF_Input dst, IntPtr status); + public static extern void UpdateEdge(IntPtr graph, TF_Output new_src, TF_Input dst, IntPtr status); } } diff --git a/src/TensorFlowNET.Core/Operations/Operation.Control.cs b/src/TensorFlowNET.Core/Operations/Operation.Control.cs index aaf2937c..73f9d847 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Control.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Control.cs @@ -29,7 +29,8 @@ namespace Tensorflow public void _add_control_input(Operation op) { - c_api.TF_AddControlInput(_operDesc, op); + // c_api.TF_AddControlInput(_operDesc, op); + c_api.AddControlInput(graph, _handle, op); } public void _add_control_inputs(Operation[] ops) diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index 07dab399..e0caef72 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -1,5 +1,4 @@ using Google.Protobuf.Collections; -using Newtonsoft.Json; //using Newtonsoft.Json; using System; using System.Collections.Generic; @@ -34,15 +33,15 @@ namespace Tensorflow private readonly IntPtr _operDesc; private Graph _graph; - [JsonIgnore] + //[JsonIgnore] public Graph graph => _graph; - [JsonIgnore] + //[JsonIgnore] public int _id => _id_value; - [JsonIgnore] + //[JsonIgnore] public int _id_value; public string type => OpType; - [JsonIgnore] + //[JsonIgnore] public Operation op => this; public TF_DataType dtype => TF_DataType.DtInvalid; private Status status = new Status(); @@ -289,7 +288,7 @@ namespace Tensorflow _inputs = null; // after the c_api call next time _inputs is accessed // the updated inputs are reloaded from the c_api - c_api.TF_UpdateEdge(_graph, output, input, status); + c_api.UpdateEdge(_graph, output, input, status); //var updated_inputs = inputs; } diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 93ea0bcc..ebe28114 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -42,6 +42,23 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_AddControlInput(IntPtr desc, IntPtr input); + /// + /// + /// + /// TF_Graph* + /// TF_Operation* + /// TF_Operation* + [DllImport(TensorFlowLibName)] + public static extern void AddControlInput(IntPtr graph, IntPtr op, IntPtr input); + + /// + /// + /// + /// TF_Graph* + /// TF_Operation* + [DllImport(TensorFlowLibName)] + public static extern void RemoveAllControlInputs(IntPtr graph, IntPtr op); + /// /// For inputs that take a list of tensors. /// inputs must point to TF_Output[num_inputs]. diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index d7ad32c8..f4d8727e 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -45,7 +45,6 @@ Bug memory leak issue when allocating Tensor. - diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 035f4bb2..ee165c93 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -1,5 +1,4 @@ //using Newtonsoft.Json; -using Newtonsoft.Json; using NumSharp; using System; using System.Collections.Generic; @@ -19,13 +18,13 @@ namespace Tensorflow private readonly IntPtr _handle; private int _id; - [JsonIgnore] + //[JsonIgnore] public int Id => _id; - [JsonIgnore] + //[JsonIgnore] public Graph graph => op?.graph; - [JsonIgnore] + //[JsonIgnore] public Operation op { get; } - [JsonIgnore] + //[JsonIgnore] public Tensor[] outputs => op.outputs; /// diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md index 2a72b8a5..ccb7532b 100644 --- a/tensorflowlib/README.md +++ b/tensorflowlib/README.md @@ -36,4 +36,17 @@ pacman -S git patch unzip 4. Install from local wheel file. -`pip install C:/tmp/tensorflow_pkg/tensorflow-1.13.0-cp36-cp36m-win_amd64.whl` \ No newline at end of file +`pip install C:/tmp/tensorflow_pkg/tensorflow-1.13.0-cp36-cp36m-win_amd64.whl` + +### Export more APIs + +Add more api to `c_api.h` + +```c++ +TF_CAPI_EXPORT extern void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); +TF_CAPI_EXPORT extern void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); +TF_CAPI_EXPORT extern void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); +``` + + + diff --git a/tensorflowlib/runtimes/win-x64/native/tensorflow.dll b/tensorflowlib/runtimes/win-x64/native/tensorflow.dll index e5f9d0b7..f98eb380 100644 Binary files a/tensorflowlib/runtimes/win-x64/native/tensorflow.dll and b/tensorflowlib/runtimes/win-x64/native/tensorflow.dll differ diff --git a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs index dd364149..58b5a086 100644 --- a/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs +++ b/test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs @@ -15,79 +15,37 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test public void testCondTrue() { var graph = tf.Graph().as_default(); - // tf.train.import_meta_graph("cond_test.meta"); - var json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); with(tf.Session(graph), sess => { - var x = tf.constant(2, name: "x"); // graph.get_operation_by_name("Const").output; - var y = tf.constant(5, name: "y"); // graph.get_operation_by_name("Const_1").output; - var pred = tf.less(x, y); // graph.get_operation_by_name("Less").output; - - Func if_true = delegate - { - return tf.constant(2, name: "t2"); - }; - - Func if_false = delegate - { - return tf.constant(5, name: "f5"); - }; + var x = tf.constant(2, name: "x"); + var y = tf.constant(5, name: "y"); - var z = control_flow_ops.cond(pred, if_true, if_false); // graph.get_operation_by_name("cond/Merge").output + var z = control_flow_ops.cond(tf.less(x, y), + () => tf.constant(22, name: "t2"), + () => tf.constant(55, name: "f5")); - json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); int result = z.eval(sess); - assertEquals(result, 2); + assertEquals(result, 22); }); } [TestMethod] public void testCondFalse() { - /* python - * import tensorflow as tf - from tensorflow.python.framework import ops - - def if_true(): - return tf.math.multiply(x, 17) - def if_false(): - return tf.math.add(y, 23) - - with tf.Session() as sess: - x = tf.constant(2) - y = tf.constant(1) - pred = tf.math.less(x,y) - z = tf.cond(pred, if_true, if_false) - result = z.eval() - - print(result == 24) */ - var graph = tf.Graph().as_default(); - //tf.train.import_meta_graph("cond_test.meta"); - //var json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); - with(tf.Session(), sess => + with(tf.Session(graph), sess => { var x = tf.constant(2, name: "x"); var y = tf.constant(1, name: "y"); - var pred = tf.less(x, y); - - Func if_true = delegate - { - return tf.constant(2, name: "t2"); - }; - - Func if_false = delegate - { - return tf.constant(1, name: "f1"); - }; - var z = control_flow_ops.cond(pred, if_true, if_false); + var z = control_flow_ops.cond(tf.less(x, y), + () => tf.constant(22, name: "t2"), + () => tf.constant(11, name: "f1")); - var json1 = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); int result = z.eval(sess); - assertEquals(result, 1); + assertEquals(result, 11); }); }