| @@ -63,6 +63,7 @@ namespace Tensorflow | |||||
| _id_value = Graph._next_id(); | _id_value = Graph._next_id(); | ||||
| _handle = ops._create_c_op(g, node_def, inputs); | _handle = ops._create_c_op(g, node_def, inputs); | ||||
| NumOutputs = c_api.TF_OperationNumOutputs(_handle); | |||||
| _outputs = new Tensor[NumOutputs]; | _outputs = new Tensor[NumOutputs]; | ||||
| for (int i = 0; i < NumOutputs; i++) | for (int i = 0; i < NumOutputs; i++) | ||||
| @@ -90,6 +90,9 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_OperationToNodeDef(IntPtr oper, IntPtr buffer, IntPtr status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status); | public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status); | ||||
| @@ -6,6 +6,7 @@ using System.Threading; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using node_def_pb2 = Tensorflow; | using node_def_pb2 = Tensorflow; | ||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using System.Linq; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -27,12 +28,14 @@ namespace Tensorflow | |||||
| var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name); | var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name); | ||||
| // Add inputs | // Add inputs | ||||
| if(inputs != null) | |||||
| if(inputs != null && inputs.Count > 0) | |||||
| { | { | ||||
| foreach (var op_input in inputs) | |||||
| /*foreach (var op_input in inputs) | |||||
| { | { | ||||
| c_api.TF_AddInput(op_desc, op_input._as_tf_output()); | c_api.TF_AddInput(op_desc, op_input._as_tf_output()); | ||||
| } | |||||
| }*/ | |||||
| c_api.TF_AddInputList(op_desc, inputs.Select(x => x._as_tf_output()).ToArray(), inputs.Count); | |||||
| } | } | ||||
| var status = new Status(); | var status = new Status(); | ||||
| @@ -40,7 +40,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| if(Code != TF_Code.TF_OK) | if(Code != TF_Code.TF_OK) | ||||
| { | { | ||||
| throw new Exception(Message); | |||||
| Console.WriteLine(Message); | |||||
| // throw new Exception(Message); | |||||
| } | } | ||||
| } | } | ||||
| @@ -118,7 +118,7 @@ namespace Tensorflow | |||||
| public TF_Output _as_tf_output() | public TF_Output _as_tf_output() | ||||
| { | { | ||||
| return c_api_util.tf_output(op, value_index); | |||||
| return new TF_Output(op, value_index); | |||||
| } | } | ||||
| public T[] Data<T>() | public T[] Data<T>() | ||||
| @@ -8,12 +8,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static TF_Output tf_output(IntPtr c_op, int index) | public static TF_Output tf_output(IntPtr c_op, int index) | ||||
| { | { | ||||
| var ret = new TF_Output(); | |||||
| ret.oper = c_op; | |||||
| ret.index = index; | |||||
| return ret; | |||||
| return new TF_Output(c_op, index); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -12,7 +12,16 @@ namespace TensorFlowNET.Examples | |||||
| foreach(Type type in assembly.GetTypes().Where(x => x.GetInterfaces().Contains(typeof(IExample)))) | foreach(Type type in assembly.GetTypes().Where(x => x.GetInterfaces().Contains(typeof(IExample)))) | ||||
| { | { | ||||
| var example = (IExample)Activator.CreateInstance(type); | var example = (IExample)Activator.CreateInstance(type); | ||||
| example.Run(); | |||||
| try | |||||
| { | |||||
| example.Run(); | |||||
| } | |||||
| catch (Exception ex) | |||||
| { | |||||
| Console.WriteLine(ex); | |||||
| Console.ReadLine(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -10,6 +10,7 @@ | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -33,20 +33,23 @@ namespace TensorFlowNET.UnitTest | |||||
| Assert.AreEqual(0, feed.NumControlOutputs); | Assert.AreEqual(0, feed.NumControlOutputs); | ||||
| AttrValue attr_value = null; | AttrValue attr_value = null; | ||||
| c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s); | |||||
| Assert.IsTrue(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s)); | |||||
| Assert.AreEqual(attr_value.Type, DataType.DtInt32); | Assert.AreEqual(attr_value.Type, DataType.DtInt32); | ||||
| // Test not found errors in TF_Operation*() query functions. | // Test not found errors in TF_Operation*() query functions. | ||||
| // Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||||
| // Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||||
| // Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||||
| // Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); | |||||
| Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||||
| Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||||
| //Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||||
| //Assert.AreEqual("Operation '' has no attr named 'missing'.", s.Message); | |||||
| // Make a constant oper with the scalar "3". | // Make a constant oper with the scalar "3". | ||||
| var three = c_test_util.ScalarConst(3, graph, s); | var three = c_test_util.ScalarConst(3, graph, s); | ||||
| // Add oper. | // Add oper. | ||||
| var add = c_test_util.Add(feed, three, graph, s); | var add = c_test_util.Add(feed, three, graph, s); | ||||
| NodeDef node_def = null; | |||||
| c_test_util.GetNodeDef(feed, ref node_def); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -43,7 +43,7 @@ namespace TensorFlowNET.UnitTest | |||||
| public void addInConstant() | public void addInConstant() | ||||
| { | { | ||||
| var a = tf.constant(4.0f); | var a = tf.constant(4.0f); | ||||
| var b = tf.constant(5.0f); | |||||
| var b = tf.placeholder(tf.float32); | |||||
| var c = tf.add(a, b); | var c = tf.add(a, b); | ||||
| using (var sess = tf.Session()) | using (var sess = tf.Session()) | ||||
| @@ -139,7 +139,7 @@ namespace TensorFlowNET.UnitTest | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); | c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); | ||||
| //Assert.IsTrue(s.Code == TF_Code.TF_OK); | //Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| graph.Dispose(); | |||||
| // graph.Dispose(); | |||||
| s.Dispose(); | s.Dispose(); | ||||
| } | } | ||||
| } | } | ||||
| @@ -42,6 +42,15 @@ namespace TensorFlowNET.UnitTest | |||||
| return s.Code == TF_Code.TF_OK; | return s.Code == TF_Code.TF_OK; | ||||
| } | } | ||||
| public static bool GetNodeDef(Operation oper, ref NodeDef node_def) | |||||
| { | |||||
| var s = new Status(); | |||||
| var buffer = new Buffer(); | |||||
| c_api.TF_OperationToNodeDef(oper, buffer, s); | |||||
| return s.Code == TF_Code.TF_OK; | |||||
| } | |||||
| public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | ||||
| { | { | ||||
| var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | ||||
| @@ -69,10 +78,6 @@ namespace TensorFlowNET.UnitTest | |||||
| c_api.TF_SetAttrType(desc, "dtype", t.dtype); | c_api.TF_SetAttrType(desc, "dtype", t.dtype); | ||||
| op = c_api.TF_FinishOperation(desc, s); | op = c_api.TF_FinishOperation(desc, s); | ||||
| s.Check(); | s.Check(); | ||||
| if(op == null) | |||||
| { | |||||
| throw new Exception("c_api.TF_FinishOperation failed."); | |||||
| } | |||||
| } | } | ||||
| public static Operation Const(Tensor t, Graph graph, Status s, string name) | public static Operation Const(Tensor t, Graph graph, Status s, string name) | ||||