| @@ -288,6 +288,6 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_UpdateEdge(IntPtr graph, TF_Output new_src, TF_Input dst, IntPtr status); | |||
| public static extern void UpdateEdge(IntPtr graph, TF_Output new_src, TF_Input dst, IntPtr status); | |||
| } | |||
| } | |||
| @@ -29,7 +29,8 @@ namespace Tensorflow | |||
| public void _add_control_input(Operation op) | |||
| { | |||
| c_api.TF_AddControlInput(_operDesc, op); | |||
| // c_api.TF_AddControlInput(_operDesc, op); | |||
| c_api.AddControlInput(graph, _handle, op); | |||
| } | |||
| public void _add_control_inputs(Operation[] ops) | |||
| @@ -1,5 +1,4 @@ | |||
| using Google.Protobuf.Collections; | |||
| using Newtonsoft.Json; | |||
| //using Newtonsoft.Json; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| @@ -34,15 +33,15 @@ namespace Tensorflow | |||
| private readonly IntPtr _operDesc; | |||
| private Graph _graph; | |||
| [JsonIgnore] | |||
| //[JsonIgnore] | |||
| public Graph graph => _graph; | |||
| [JsonIgnore] | |||
| //[JsonIgnore] | |||
| public int _id => _id_value; | |||
| [JsonIgnore] | |||
| //[JsonIgnore] | |||
| public int _id_value; | |||
| public string type => OpType; | |||
| [JsonIgnore] | |||
| //[JsonIgnore] | |||
| public Operation op => this; | |||
| public TF_DataType dtype => TF_DataType.DtInvalid; | |||
| private Status status = new Status(); | |||
| @@ -289,7 +288,7 @@ namespace Tensorflow | |||
| _inputs = null; | |||
| // after the c_api call next time _inputs is accessed | |||
| // the updated inputs are reloaded from the c_api | |||
| c_api.TF_UpdateEdge(_graph, output, input, status); | |||
| c_api.UpdateEdge(_graph, output, input, status); | |||
| //var updated_inputs = inputs; | |||
| } | |||
| @@ -42,6 +42,23 @@ namespace Tensorflow | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_AddControlInput(IntPtr desc, IntPtr input); | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="graph">TF_Graph*</param> | |||
| /// <param name="op">TF_Operation*</param> | |||
| /// <param name="input">TF_Operation*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void AddControlInput(IntPtr graph, IntPtr op, IntPtr input); | |||
| /// <summary> | |||
| /// | |||
| /// </summary> | |||
| /// <param name="graph">TF_Graph*</param> | |||
| /// <param name="op">TF_Operation*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void RemoveAllControlInputs(IntPtr graph, IntPtr op); | |||
| /// <summary> | |||
| /// For inputs that take a list of tensors. | |||
| /// inputs must point to TF_Output[num_inputs]. | |||
| @@ -45,7 +45,6 @@ Bug memory leak issue when allocating Tensor.</PackageReleaseNotes> | |||
| <ItemGroup> | |||
| <PackageReference Include="Google.Protobuf" Version="3.7.0" /> | |||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.1" /> | |||
| </ItemGroup> | |||
| <ItemGroup> | |||
| @@ -1,5 +1,4 @@ | |||
| //using Newtonsoft.Json; | |||
| using Newtonsoft.Json; | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| @@ -19,13 +18,13 @@ namespace Tensorflow | |||
| private readonly IntPtr _handle; | |||
| private int _id; | |||
| [JsonIgnore] | |||
| //[JsonIgnore] | |||
| public int Id => _id; | |||
| [JsonIgnore] | |||
| //[JsonIgnore] | |||
| public Graph graph => op?.graph; | |||
| [JsonIgnore] | |||
| //[JsonIgnore] | |||
| public Operation op { get; } | |||
| [JsonIgnore] | |||
| //[JsonIgnore] | |||
| public Tensor[] outputs => op.outputs; | |||
| /// <summary> | |||
| @@ -36,4 +36,17 @@ pacman -S git patch unzip | |||
| 4. Install from local wheel file. | |||
| `pip install C:/tmp/tensorflow_pkg/tensorflow-1.13.0-cp36-cp36m-win_amd64.whl` | |||
| `pip install C:/tmp/tensorflow_pkg/tensorflow-1.13.0-cp36-cp36m-win_amd64.whl` | |||
| ### Export more APIs | |||
| Add more api to `c_api.h` | |||
| ```c++ | |||
| TF_CAPI_EXPORT extern void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); | |||
| TF_CAPI_EXPORT extern void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); | |||
| TF_CAPI_EXPORT extern void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); | |||
| ``` | |||
| @@ -15,79 +15,37 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| public void testCondTrue() | |||
| { | |||
| var graph = tf.Graph().as_default(); | |||
| // tf.train.import_meta_graph("cond_test.meta"); | |||
| var json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); | |||
| with(tf.Session(graph), sess => | |||
| { | |||
| var x = tf.constant(2, name: "x"); // graph.get_operation_by_name("Const").output; | |||
| var y = tf.constant(5, name: "y"); // graph.get_operation_by_name("Const_1").output; | |||
| var pred = tf.less(x, y); // graph.get_operation_by_name("Less").output; | |||
| Func<ITensorOrOperation> if_true = delegate | |||
| { | |||
| return tf.constant(2, name: "t2"); | |||
| }; | |||
| Func<ITensorOrOperation> if_false = delegate | |||
| { | |||
| return tf.constant(5, name: "f5"); | |||
| }; | |||
| var x = tf.constant(2, name: "x"); | |||
| var y = tf.constant(5, name: "y"); | |||
| var z = control_flow_ops.cond(pred, if_true, if_false); // graph.get_operation_by_name("cond/Merge").output | |||
| var z = control_flow_ops.cond(tf.less(x, y), | |||
| () => tf.constant(22, name: "t2"), | |||
| () => tf.constant(55, name: "f5")); | |||
| json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); | |||
| int result = z.eval(sess); | |||
| assertEquals(result, 2); | |||
| assertEquals(result, 22); | |||
| }); | |||
| } | |||
| [TestMethod] | |||
| public void testCondFalse() | |||
| { | |||
| /* python | |||
| * import tensorflow as tf | |||
| from tensorflow.python.framework import ops | |||
| def if_true(): | |||
| return tf.math.multiply(x, 17) | |||
| def if_false(): | |||
| return tf.math.add(y, 23) | |||
| with tf.Session() as sess: | |||
| x = tf.constant(2) | |||
| y = tf.constant(1) | |||
| pred = tf.math.less(x,y) | |||
| z = tf.cond(pred, if_true, if_false) | |||
| result = z.eval() | |||
| print(result == 24) */ | |||
| var graph = tf.Graph().as_default(); | |||
| //tf.train.import_meta_graph("cond_test.meta"); | |||
| //var json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); | |||
| with(tf.Session(), sess => | |||
| with(tf.Session(graph), sess => | |||
| { | |||
| var x = tf.constant(2, name: "x"); | |||
| var y = tf.constant(1, name: "y"); | |||
| var pred = tf.less(x, y); | |||
| Func<ITensorOrOperation> if_true = delegate | |||
| { | |||
| return tf.constant(2, name: "t2"); | |||
| }; | |||
| Func<ITensorOrOperation> if_false = delegate | |||
| { | |||
| return tf.constant(1, name: "f1"); | |||
| }; | |||
| var z = control_flow_ops.cond(pred, if_true, if_false); | |||
| var z = control_flow_ops.cond(tf.less(x, y), | |||
| () => tf.constant(22, name: "t2"), | |||
| () => tf.constant(11, name: "f1")); | |||
| var json1 = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); | |||
| int result = z.eval(sess); | |||
| assertEquals(result, 1); | |||
| assertEquals(result, 11); | |||
| }); | |||
| } | |||