| @@ -1,5 +1,4 @@ | |||||
| //using Newtonsoft.Json; | |||||
| using System; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| @@ -15,10 +14,11 @@ namespace Tensorflow | |||||
| private Tensor[] _outputs; | private Tensor[] _outputs; | ||||
| public Tensor[] outputs => _outputs; | public Tensor[] outputs => _outputs; | ||||
| //[JsonIgnore] | |||||
| public Tensor output => _outputs.FirstOrDefault(); | public Tensor output => _outputs.FirstOrDefault(); | ||||
| public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | ||||
| public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); | ||||
| public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) | ||||
| @@ -1,5 +1,6 @@ | |||||
| using Google.Protobuf.Collections; | using Google.Protobuf.Collections; | ||||
| //using Newtonsoft.Json; | |||||
| using Newtonsoft.Json; | |||||
| //using Newtonsoft.Json; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| @@ -33,15 +34,15 @@ namespace Tensorflow | |||||
| private readonly IntPtr _operDesc; | private readonly IntPtr _operDesc; | ||||
| private Graph _graph; | private Graph _graph; | ||||
| //[JsonIgnore] | |||||
| [JsonIgnore] | |||||
| public Graph graph => _graph; | public Graph graph => _graph; | ||||
| //[JsonIgnore] | |||||
| [JsonIgnore] | |||||
| public int _id => _id_value; | public int _id => _id_value; | ||||
| //[JsonIgnore] | |||||
| [JsonIgnore] | |||||
| public int _id_value; | public int _id_value; | ||||
| public string type => OpType; | public string type => OpType; | ||||
| //[JsonIgnore] | |||||
| [JsonIgnore] | |||||
| public Operation op => this; | public Operation op => this; | ||||
| public TF_DataType dtype => TF_DataType.DtInvalid; | public TF_DataType dtype => TF_DataType.DtInvalid; | ||||
| private Status status = new Status(); | private Status status = new Status(); | ||||
| @@ -45,6 +45,7 @@ Bug memory leak issue when allocating Tensor.</PackageReleaseNotes> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="Google.Protobuf" Version="3.7.0" /> | <PackageReference Include="Google.Protobuf" Version="3.7.0" /> | ||||
| <PackageReference Include="Newtonsoft.Json" Version="12.0.1" /> | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| @@ -1,4 +1,5 @@ | |||||
| //using Newtonsoft.Json; | //using Newtonsoft.Json; | ||||
| using Newtonsoft.Json; | |||||
| using NumSharp; | using NumSharp; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| @@ -18,13 +19,13 @@ namespace Tensorflow | |||||
| private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
| private int _id; | private int _id; | ||||
| //[JsonIgnore] | |||||
| [JsonIgnore] | |||||
| public int Id => _id; | public int Id => _id; | ||||
| //[JsonIgnore] | |||||
| [JsonIgnore] | |||||
| public Graph graph => op?.graph; | public Graph graph => op?.graph; | ||||
| //[JsonIgnore] | |||||
| [JsonIgnore] | |||||
| public Operation op { get; } | public Operation op { get; } | ||||
| //[JsonIgnore] | |||||
| [JsonIgnore] | |||||
| public Tensor[] outputs => op.outputs; | public Tensor[] outputs => op.outputs; | ||||
| /// <summary> | /// <summary> | ||||
| @@ -112,9 +113,6 @@ namespace Tensorflow | |||||
| public int NDims => rank; | public int NDims => rank; | ||||
| //[JsonIgnore] | |||||
| public Operation[] Consumers => consumers(); | |||||
| public string Device => op.Device; | public string Device => op.Device; | ||||
| public Operation[] consumers() | public Operation[] consumers() | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Newtonsoft.Json; | |||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| @@ -14,26 +15,30 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
| public void testCondTrue() | public void testCondTrue() | ||||
| { | { | ||||
| var graph = tf.Graph().as_default(); | var graph = tf.Graph().as_default(); | ||||
| // tf.train.import_meta_graph("cond_test.meta"); | |||||
| var json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); | |||||
| with(tf.Session(graph), sess => | with(tf.Session(graph), sess => | ||||
| { | { | ||||
| var x = tf.constant(2); | |||||
| var y = tf.constant(5); | |||||
| var pred = tf.less(x, y); | |||||
| var x = tf.constant(2, name: "x"); // graph.get_operation_by_name("Const").output; | |||||
| var y = tf.constant(5, name: "y"); // graph.get_operation_by_name("Const_1").output; | |||||
| var pred = tf.less(x, y); // graph.get_operation_by_name("Less").output; | |||||
| Func<ITensorOrOperation> if_true = delegate | Func<ITensorOrOperation> if_true = delegate | ||||
| { | { | ||||
| return tf.multiply(x, 17); | |||||
| return tf.constant(2, name: "t2"); | |||||
| }; | }; | ||||
| Func<ITensorOrOperation> if_false = delegate | Func<ITensorOrOperation> if_false = delegate | ||||
| { | { | ||||
| return tf.add(y, 23); | |||||
| return tf.constant(5, name: "f5"); | |||||
| }; | }; | ||||
| var z = control_flow_ops.cond(pred, if_true, if_false); | |||||
| var z = control_flow_ops.cond(pred, if_true, if_false); // graph.get_operation_by_name("cond/Merge").output | |||||
| json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); | |||||
| int result = z.eval(sess); | int result = z.eval(sess); | ||||
| assertEquals(result, 34); | |||||
| assertEquals(result, 2); | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -58,25 +63,31 @@ namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
| print(result == 24) */ | print(result == 24) */ | ||||
| var graph = tf.Graph().as_default(); | |||||
| //tf.train.import_meta_graph("cond_test.meta"); | |||||
| //var json = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); | |||||
| with(tf.Session(), sess => | with(tf.Session(), sess => | ||||
| { | { | ||||
| var x = tf.constant(2); | |||||
| var y = tf.constant(1); | |||||
| var x = tf.constant(2, name: "x"); | |||||
| var y = tf.constant(1, name: "y"); | |||||
| var pred = tf.less(x, y); | var pred = tf.less(x, y); | ||||
| Func<ITensorOrOperation> if_true = delegate | Func<ITensorOrOperation> if_true = delegate | ||||
| { | { | ||||
| return tf.multiply(x, 17); | |||||
| return tf.constant(2, name: "t2"); | |||||
| }; | }; | ||||
| Func<ITensorOrOperation> if_false = delegate | Func<ITensorOrOperation> if_false = delegate | ||||
| { | { | ||||
| return tf.add(y, 23); | |||||
| return tf.constant(1, name: "f1"); | |||||
| }; | }; | ||||
| var z = control_flow_ops.cond(pred, if_true, if_false); | var z = control_flow_ops.cond(pred, if_true, if_false); | ||||
| var json1 = JsonConvert.SerializeObject(graph._nodes_by_name, Formatting.Indented); | |||||
| int result = z.eval(sess); | int result = z.eval(sess); | ||||
| assertEquals(result, 24); | |||||
| assertEquals(result, 1); | |||||
| }); | }); | ||||
| } | } | ||||