diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index f4516674..7e387522 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -8,7 +8,6 @@ namespace Tensorflow public class Buffer { private IntPtr _handle; - public IntPtr Handle => _handle; private TF_Buffer buffer; @@ -21,7 +20,8 @@ namespace Tensorflow _handle = handle; buffer = Marshal.PtrToStructure(_handle); Data = new byte[buffer.length]; - Marshal.Copy(buffer.data, Data, 0, (int)buffer.length); + if (buffer.length > 0) + Marshal.Copy(buffer.data, Data, 0, (int)buffer.length); } } } diff --git a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs index 451b43b6..86857392 100644 --- a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs @@ -7,6 +7,13 @@ namespace Tensorflow { public static partial class c_api { + /// + /// Useful for passing *out* a protobuf. + /// + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_NewBuffer(); + [DllImport(TensorFlowLibName)] public static extern IntPtr TF_GetBuffer(TF_Buffer buffer); } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 5acfc117..a27d031e 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -64,7 +64,7 @@ namespace Tensorflow if(obj is Tensor && allow_tensor) { - if ((obj as Tensor).graph.Equals(this)) + if ((obj as Tensor).Graph.Equals(this)) { return obj; } diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index fb12dae5..f89a16bc 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -6,28 +6,37 @@ namespace Tensorflow { public class Operation { - public IntPtr Handle { get; } + private readonly IntPtr _handle; - private Graph _graph; - public Graph graph => _graph; - public IntPtr _c_op; + public Graph Graph { get; } public int _id => _id_value; private int _id_value; - public string name; + + private Status status = new Status(); + + public string name => c_api.TF_OperationName(_handle); + public string optype => c_api.TF_OperationOpType(_handle); + public string device => c_api.TF_OperationDevice(_handle); + public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); + public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); + public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); + public int NumInputs => c_api.TF_OperationNumInputs(_handle); + public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); + public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); + public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); + private Tensor[] _outputs; public Tensor[] outputs => _outputs; public Tensor[] inputs; public Operation(IntPtr handle) { - Handle = handle; + _handle = handle; } public Operation(Graph g, string opType, string oper_name) { - _graph = g; - - var status = new Status(); + Graph = g; var desc = c_api.TF_NewOperation(g, opType, oper_name); c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); @@ -36,19 +45,18 @@ namespace Tensorflow public Operation(NodeDef node_def, Graph g, List inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) { - _graph = g; + Graph = g; - _id_value = _graph._next_id(); - _c_op = ops._create_c_op(g, node_def, inputs); - var num_outputs = c_api.TF_OperationNumOutputs(_c_op); + _id_value = Graph._next_id(); + _handle = ops._create_c_op(g, node_def, inputs); - _outputs = new Tensor[num_outputs]; - for (int i = 0; i < num_outputs; i++) + _outputs = new Tensor[NumOutputs]; + for (int i = 0; i < NumOutputs; i++) { _outputs[i] = new Tensor(this, i, output_types[i]); } - _graph._add_op(this); + Graph._add_op(this); } public object get_attr(string name) @@ -69,5 +77,15 @@ namespace Tensorflow return ret; } + + public static implicit operator Operation(IntPtr handle) + { + return new Operation(handle); + } + + public static implicit operator IntPtr(Operation op) + { + return op._handle; + } } } diff --git a/src/TensorFlowNET.Core/Operations/TF_Output.cs b/src/TensorFlowNET.Core/Operations/TF_Output.cs index 98ec3d17..16d0285a 100644 --- a/src/TensorFlowNET.Core/Operations/TF_Output.cs +++ b/src/TensorFlowNET.Core/Operations/TF_Output.cs @@ -8,6 +8,12 @@ namespace Tensorflow [StructLayout(LayoutKind.Sequential)] public struct TF_Output { + public TF_Output(IntPtr oper, int index) + { + this.oper = oper; + this.index = index; + } + public IntPtr oper; public int index; } diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 0fafaa10..39b82b13 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -12,7 +12,7 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_GetAllOpList(); + public static extern IntPtr TF_GetAllOpList(); /// /// For inputs that take a single tensor. @@ -20,24 +20,78 @@ namespace Tensorflow /// /// [DllImport(TensorFlowLibName)] - public static unsafe extern void TF_AddInput(IntPtr desc, TF_Output input); + public static extern void TF_AddInput(IntPtr desc, TF_Output input); [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); + public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); + public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); [DllImport(TensorFlowLibName)] - public static extern unsafe int TF_OperationNumOutputs(IntPtr oper); + public static extern string TF_OperationDevice(IntPtr oper); [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); + public static extern string TF_OperationName(IntPtr oper); [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); + public static extern int TF_OperationNumInputs(IntPtr oper); [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value); + public static extern string TF_OperationOpType(IntPtr oper); + + /// + /// Get the number of control inputs to an operation. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationNumControlInputs(IntPtr oper); + + /// + /// Get the number of operations that have `*oper` as a control input. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationNumControlOutputs(IntPtr oper); + + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationNumOutputs(IntPtr oper); + + /// + /// Get the number of current consumers of a specific output of an + /// operation. Note that this number can change when new operations + /// are added to the graph. + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationOutputNumConsumers(TF_Output oper_out); + + [DllImport(TensorFlowLibName)] + public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); + + [DllImport(TensorFlowLibName)] + public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); + + /// + /// Set `num_dims` to -1 to represent "unknown rank". + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttrShape(IntPtr desc, string attr_name, long[] dims, int num_dims); + + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value); } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 6e27b305..56748f3a 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -13,9 +13,9 @@ namespace Tensorflow /// public class Tensor { - public IntPtr Handle { get; } + private readonly IntPtr _handle; - public Graph graph => op.graph; + public Graph Graph => op.Graph; public Operation op { get; } public string name; @@ -46,7 +46,7 @@ namespace Tensorflow public Tensor(IntPtr handle) { - Handle = handle; + _handle = handle; dtype = c_api.TF_TensorType(handle); rank = c_api.TF_NumDims(handle); bytesize = c_api.TF_TensorByteSize(handle); @@ -60,16 +60,16 @@ namespace Tensorflow public Tensor(NDArray nd) { - Handle = Allocate(nd); - dtype = c_api.TF_TensorType(Handle); - rank = c_api.TF_NumDims(Handle); - bytesize = c_api.TF_TensorByteSize(Handle); - buffer = c_api.TF_TensorData(Handle); + _handle = Allocate(nd); + dtype = c_api.TF_TensorType(_handle); + rank = c_api.TF_NumDims(_handle); + bytesize = c_api.TF_TensorByteSize(_handle); + buffer = c_api.TF_TensorData(_handle); dataTypeSize = c_api.TF_DataTypeSize(dtype); shape = new long[rank]; for (int i = 0; i < rank; i++) - shape[i] = c_api.TF_Dim(Handle, i); + shape[i] = c_api.TF_Dim(_handle, i); } private IntPtr Allocate(NDArray nd) @@ -117,7 +117,7 @@ namespace Tensorflow public TF_Output _as_tf_output() { - return c_api_util.tf_output(op._c_op, value_index); + return c_api_util.tf_output(op, value_index); } public T[] Data() @@ -162,7 +162,7 @@ namespace Tensorflow public static implicit operator IntPtr(Tensor tensor) { - return tensor.Handle; + return tensor._handle; } } } diff --git a/test/TensorFlowNET.UnitTest/GraphTest.cs b/test/TensorFlowNET.UnitTest/GraphTest.cs index 1864fede..1cb1c7e7 100644 --- a/test/TensorFlowNET.UnitTest/GraphTest.cs +++ b/test/TensorFlowNET.UnitTest/GraphTest.cs @@ -10,9 +10,26 @@ namespace TensorFlowNET.UnitTest public class GraphTest { [TestMethod] - public void ConstructGraph() + public void Graph() { - var g = tf.Graph(); + var s = new Status(); + var graph = tf.get_default_graph(); + + // Make a placeholder operation. + var feed = c_test_util.Placeholder(graph, s); + Assert.AreEqual("feed", feed.name); + Assert.AreEqual("Placeholder", feed.optype); + //Assert.AreEqual("", feed.device); + Assert.AreEqual(1, feed.NumOutputs); + Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType); + Assert.AreEqual(1, feed.OutputListLength); + Assert.AreEqual(0, feed.NumInputs); + Assert.AreEqual(0, feed.NumConsumers); + Assert.AreEqual(0, feed.NumControlInputs); + Assert.AreEqual(0, feed.NumControlOutputs); + + var attr_value = new AttrValue(); + c_test_util.GetAttrValue(feed, "dtype", attr_value, s); } } } diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index bb7c8f27..7b083b16 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -87,7 +87,7 @@ namespace TensorFlowNET.UnitTest // Test for a scalar. var three = c_test_util.ScalarConst(3, graph, s); Assert.IsTrue(s.Code == TF_Code.TF_OK); - var three_out_0 = new TF_Output { oper = three.Handle }; + var three_out_0 = new TF_Output { oper = three }; num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); Assert.AreEqual(0, num_dims); } diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs index add1913c..015081d8 100644 --- a/test/TensorFlowNET.UnitTest/c_test_util.cs +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -3,15 +3,42 @@ using System.Collections.Generic; using System.Runtime.InteropServices; using System.Text; using Tensorflow; +using Buffer = Tensorflow.Buffer; namespace TensorFlowNET.UnitTest { public static class c_test_util { - public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref IntPtr op) + public static bool GetAttrValue(Operation oper, string attr_name, AttrValue attr_value, Status s) + { + var buffer = c_api.TF_NewBuffer(); + + return s.Code == TF_Code.TF_OK; + } + + public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) + { + var desc = c_api.TF_NewOperation(graph, "Placeholder", name); + c_api.TF_SetAttrType(desc, "dtype", dtype); + if(dims != null) + { + c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); + } + op = c_api.TF_FinishOperation(desc, s); + s.Check(); + } + + public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) + { + Operation op = null; + PlaceholderHelper(graph, s, name, dtype, dims, ref op); + return op; + } + + public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref Operation op) { var desc = c_api.TF_NewOperation(graph, "Const", name); - c_api.TF_SetAttrTensor(desc, "value", t.Handle, s); + c_api.TF_SetAttrTensor(desc, "value", t, s); s.Check(); c_api.TF_SetAttrType(desc, "dtype", t.dtype); op = c_api.TF_FinishOperation(desc, s); @@ -24,9 +51,9 @@ namespace TensorFlowNET.UnitTest public static Operation Const(Tensor t, Graph graph, Status s, string name) { - IntPtr op = IntPtr.Zero; + Operation op = null; ConstHelper(t, graph, s, name, ref op); - return new Operation(op); + return op; } public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const")