| @@ -8,7 +8,6 @@ namespace Tensorflow | |||||
| public class Buffer | public class Buffer | ||||
| { | { | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| public IntPtr Handle => _handle; | |||||
| private TF_Buffer buffer; | private TF_Buffer buffer; | ||||
| @@ -21,7 +20,8 @@ namespace Tensorflow | |||||
| _handle = handle; | _handle = handle; | ||||
| buffer = Marshal.PtrToStructure<TF_Buffer>(_handle); | buffer = Marshal.PtrToStructure<TF_Buffer>(_handle); | ||||
| Data = new byte[buffer.length]; | Data = new byte[buffer.length]; | ||||
| Marshal.Copy(buffer.data, Data, 0, (int)buffer.length); | |||||
| if (buffer.length > 0) | |||||
| Marshal.Copy(buffer.data, Data, 0, (int)buffer.length); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -7,6 +7,13 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static partial class c_api | public static partial class c_api | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Useful for passing *out* a protobuf. | |||||
| /// </summary> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TF_NewBuffer(); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_GetBuffer(TF_Buffer buffer); | public static extern IntPtr TF_GetBuffer(TF_Buffer buffer); | ||||
| } | } | ||||
| @@ -64,7 +64,7 @@ namespace Tensorflow | |||||
| if(obj is Tensor && allow_tensor) | if(obj is Tensor && allow_tensor) | ||||
| { | { | ||||
| if ((obj as Tensor).graph.Equals(this)) | |||||
| if ((obj as Tensor).Graph.Equals(this)) | |||||
| { | { | ||||
| return obj; | return obj; | ||||
| } | } | ||||
| @@ -6,28 +6,37 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class Operation | public class Operation | ||||
| { | { | ||||
| public IntPtr Handle { get; } | |||||
| private readonly IntPtr _handle; | |||||
| private Graph _graph; | |||||
| public Graph graph => _graph; | |||||
| public IntPtr _c_op; | |||||
| public Graph Graph { get; } | |||||
| public int _id => _id_value; | public int _id => _id_value; | ||||
| private int _id_value; | private int _id_value; | ||||
| public string name; | |||||
| private Status status = new Status(); | |||||
| public string name => c_api.TF_OperationName(_handle); | |||||
| public string optype => c_api.TF_OperationOpType(_handle); | |||||
| public string device => c_api.TF_OperationDevice(_handle); | |||||
| public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | |||||
| public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||||
| public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); | |||||
| public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||||
| public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||||
| public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||||
| public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||||
| private Tensor[] _outputs; | private Tensor[] _outputs; | ||||
| public Tensor[] outputs => _outputs; | public Tensor[] outputs => _outputs; | ||||
| public Tensor[] inputs; | public Tensor[] inputs; | ||||
| public Operation(IntPtr handle) | public Operation(IntPtr handle) | ||||
| { | { | ||||
| Handle = handle; | |||||
| _handle = handle; | |||||
| } | } | ||||
| public Operation(Graph g, string opType, string oper_name) | public Operation(Graph g, string opType, string oper_name) | ||||
| { | { | ||||
| _graph = g; | |||||
| var status = new Status(); | |||||
| Graph = g; | |||||
| var desc = c_api.TF_NewOperation(g, opType, oper_name); | var desc = c_api.TF_NewOperation(g, opType, oper_name); | ||||
| c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); | c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); | ||||
| @@ -36,19 +45,18 @@ namespace Tensorflow | |||||
| public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | ||||
| { | { | ||||
| _graph = g; | |||||
| Graph = g; | |||||
| _id_value = _graph._next_id(); | |||||
| _c_op = ops._create_c_op(g, node_def, inputs); | |||||
| var num_outputs = c_api.TF_OperationNumOutputs(_c_op); | |||||
| _id_value = Graph._next_id(); | |||||
| _handle = ops._create_c_op(g, node_def, inputs); | |||||
| _outputs = new Tensor[num_outputs]; | |||||
| for (int i = 0; i < num_outputs; i++) | |||||
| _outputs = new Tensor[NumOutputs]; | |||||
| for (int i = 0; i < NumOutputs; i++) | |||||
| { | { | ||||
| _outputs[i] = new Tensor(this, i, output_types[i]); | _outputs[i] = new Tensor(this, i, output_types[i]); | ||||
| } | } | ||||
| _graph._add_op(this); | |||||
| Graph._add_op(this); | |||||
| } | } | ||||
| public object get_attr(string name) | public object get_attr(string name) | ||||
| @@ -69,5 +77,15 @@ namespace Tensorflow | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| public static implicit operator Operation(IntPtr handle) | |||||
| { | |||||
| return new Operation(handle); | |||||
| } | |||||
| public static implicit operator IntPtr(Operation op) | |||||
| { | |||||
| return op._handle; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -8,6 +8,12 @@ namespace Tensorflow | |||||
| [StructLayout(LayoutKind.Sequential)] | [StructLayout(LayoutKind.Sequential)] | ||||
| public struct TF_Output | public struct TF_Output | ||||
| { | { | ||||
| public TF_Output(IntPtr oper, int index) | |||||
| { | |||||
| this.oper = oper; | |||||
| this.index = index; | |||||
| } | |||||
| public IntPtr oper; | public IntPtr oper; | ||||
| public int index; | public int index; | ||||
| } | } | ||||
| @@ -12,7 +12,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern IntPtr TF_GetAllOpList(); | |||||
| public static extern IntPtr TF_GetAllOpList(); | |||||
| /// <summary> | /// <summary> | ||||
| /// For inputs that take a single tensor. | /// For inputs that take a single tensor. | ||||
| @@ -20,24 +20,78 @@ namespace Tensorflow | |||||
| /// <param name="desc"></param> | /// <param name="desc"></param> | ||||
| /// <param name="input"></param> | /// <param name="input"></param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern void TF_AddInput(IntPtr desc, TF_Output input); | |||||
| public static extern void TF_AddInput(IntPtr desc, TF_Output input); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | |||||
| public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | |||||
| public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe int TF_OperationNumOutputs(IntPtr oper); | |||||
| public static extern string TF_OperationDevice(IntPtr oper); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); | |||||
| public static extern string TF_OperationName(IntPtr oper); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); | |||||
| public static extern int TF_OperationNumInputs(IntPtr oper); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value); | |||||
| public static extern string TF_OperationOpType(IntPtr oper); | |||||
| /// <summary> | |||||
| /// Get the number of control inputs to an operation. | |||||
| /// </summary> | |||||
| /// <param name="oper"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TF_OperationNumControlInputs(IntPtr oper); | |||||
| /// <summary> | |||||
| /// Get the number of operations that have `*oper` as a control input. | |||||
| /// </summary> | |||||
| /// <param name="oper"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TF_OperationNumControlOutputs(IntPtr oper); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TF_OperationNumOutputs(IntPtr oper); | |||||
| /// <summary> | |||||
| /// Get the number of current consumers of a specific output of an | |||||
| /// operation. Note that this number can change when new operations | |||||
| /// are added to the graph. | |||||
| /// </summary> | |||||
| /// <param name="oper_out"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TF_OperationOutputNumConsumers(TF_Output oper_out); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); | |||||
| /// <summary> | |||||
| /// Set `num_dims` to -1 to represent "unknown rank". | |||||
| /// </summary> | |||||
| /// <param name="desc"></param> | |||||
| /// <param name="attr_name"></param> | |||||
| /// <param name="dims"></param> | |||||
| /// <param name="num_dims"></param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_SetAttrShape(IntPtr desc, string attr_name, long[] dims, int num_dims); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value); | |||||
| } | } | ||||
| } | } | ||||
| @@ -13,9 +13,9 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class Tensor | public class Tensor | ||||
| { | { | ||||
| public IntPtr Handle { get; } | |||||
| private readonly IntPtr _handle; | |||||
| public Graph graph => op.graph; | |||||
| public Graph Graph => op.Graph; | |||||
| public Operation op { get; } | public Operation op { get; } | ||||
| public string name; | public string name; | ||||
| @@ -46,7 +46,7 @@ namespace Tensorflow | |||||
| public Tensor(IntPtr handle) | public Tensor(IntPtr handle) | ||||
| { | { | ||||
| Handle = handle; | |||||
| _handle = handle; | |||||
| dtype = c_api.TF_TensorType(handle); | dtype = c_api.TF_TensorType(handle); | ||||
| rank = c_api.TF_NumDims(handle); | rank = c_api.TF_NumDims(handle); | ||||
| bytesize = c_api.TF_TensorByteSize(handle); | bytesize = c_api.TF_TensorByteSize(handle); | ||||
| @@ -60,16 +60,16 @@ namespace Tensorflow | |||||
| public Tensor(NDArray nd) | public Tensor(NDArray nd) | ||||
| { | { | ||||
| Handle = Allocate(nd); | |||||
| dtype = c_api.TF_TensorType(Handle); | |||||
| rank = c_api.TF_NumDims(Handle); | |||||
| bytesize = c_api.TF_TensorByteSize(Handle); | |||||
| buffer = c_api.TF_TensorData(Handle); | |||||
| _handle = Allocate(nd); | |||||
| dtype = c_api.TF_TensorType(_handle); | |||||
| rank = c_api.TF_NumDims(_handle); | |||||
| bytesize = c_api.TF_TensorByteSize(_handle); | |||||
| buffer = c_api.TF_TensorData(_handle); | |||||
| dataTypeSize = c_api.TF_DataTypeSize(dtype); | dataTypeSize = c_api.TF_DataTypeSize(dtype); | ||||
| shape = new long[rank]; | shape = new long[rank]; | ||||
| for (int i = 0; i < rank; i++) | for (int i = 0; i < rank; i++) | ||||
| shape[i] = c_api.TF_Dim(Handle, i); | |||||
| shape[i] = c_api.TF_Dim(_handle, i); | |||||
| } | } | ||||
| private IntPtr Allocate(NDArray nd) | private IntPtr Allocate(NDArray nd) | ||||
| @@ -117,7 +117,7 @@ namespace Tensorflow | |||||
| public TF_Output _as_tf_output() | public TF_Output _as_tf_output() | ||||
| { | { | ||||
| return c_api_util.tf_output(op._c_op, value_index); | |||||
| return c_api_util.tf_output(op, value_index); | |||||
| } | } | ||||
| public T[] Data<T>() | public T[] Data<T>() | ||||
| @@ -162,7 +162,7 @@ namespace Tensorflow | |||||
| public static implicit operator IntPtr(Tensor tensor) | public static implicit operator IntPtr(Tensor tensor) | ||||
| { | { | ||||
| return tensor.Handle; | |||||
| return tensor._handle; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -10,9 +10,26 @@ namespace TensorFlowNET.UnitTest | |||||
| public class GraphTest | public class GraphTest | ||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void ConstructGraph() | |||||
| public void Graph() | |||||
| { | { | ||||
| var g = tf.Graph(); | |||||
| var s = new Status(); | |||||
| var graph = tf.get_default_graph(); | |||||
| // Make a placeholder operation. | |||||
| var feed = c_test_util.Placeholder(graph, s); | |||||
| Assert.AreEqual("feed", feed.name); | |||||
| Assert.AreEqual("Placeholder", feed.optype); | |||||
| //Assert.AreEqual("", feed.device); | |||||
| Assert.AreEqual(1, feed.NumOutputs); | |||||
| Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType); | |||||
| Assert.AreEqual(1, feed.OutputListLength); | |||||
| Assert.AreEqual(0, feed.NumInputs); | |||||
| Assert.AreEqual(0, feed.NumConsumers); | |||||
| Assert.AreEqual(0, feed.NumControlInputs); | |||||
| Assert.AreEqual(0, feed.NumControlOutputs); | |||||
| var attr_value = new AttrValue(); | |||||
| c_test_util.GetAttrValue(feed, "dtype", attr_value, s); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -87,7 +87,7 @@ namespace TensorFlowNET.UnitTest | |||||
| // Test for a scalar. | // Test for a scalar. | ||||
| var three = c_test_util.ScalarConst(3, graph, s); | var three = c_test_util.ScalarConst(3, graph, s); | ||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| var three_out_0 = new TF_Output { oper = three.Handle }; | |||||
| var three_out_0 = new TF_Output { oper = three }; | |||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | ||||
| Assert.AreEqual(0, num_dims); | Assert.AreEqual(0, num_dims); | ||||
| } | } | ||||
| @@ -3,15 +3,42 @@ using System.Collections.Generic; | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Buffer = Tensorflow.Buffer; | |||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| { | { | ||||
| public static class c_test_util | public static class c_test_util | ||||
| { | { | ||||
| public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref IntPtr op) | |||||
| public static bool GetAttrValue(Operation oper, string attr_name, AttrValue attr_value, Status s) | |||||
| { | |||||
| var buffer = c_api.TF_NewBuffer(); | |||||
| return s.Code == TF_Code.TF_OK; | |||||
| } | |||||
| public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | |||||
| { | |||||
| var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||||
| c_api.TF_SetAttrType(desc, "dtype", dtype); | |||||
| if(dims != null) | |||||
| { | |||||
| c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | |||||
| } | |||||
| op = c_api.TF_FinishOperation(desc, s); | |||||
| s.Check(); | |||||
| } | |||||
| public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) | |||||
| { | |||||
| Operation op = null; | |||||
| PlaceholderHelper(graph, s, name, dtype, dims, ref op); | |||||
| return op; | |||||
| } | |||||
| public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref Operation op) | |||||
| { | { | ||||
| var desc = c_api.TF_NewOperation(graph, "Const", name); | var desc = c_api.TF_NewOperation(graph, "Const", name); | ||||
| c_api.TF_SetAttrTensor(desc, "value", t.Handle, s); | |||||
| c_api.TF_SetAttrTensor(desc, "value", t, s); | |||||
| s.Check(); | s.Check(); | ||||
| c_api.TF_SetAttrType(desc, "dtype", t.dtype); | c_api.TF_SetAttrType(desc, "dtype", t.dtype); | ||||
| op = c_api.TF_FinishOperation(desc, s); | op = c_api.TF_FinishOperation(desc, s); | ||||
| @@ -24,9 +51,9 @@ namespace TensorFlowNET.UnitTest | |||||
| public static Operation Const(Tensor t, Graph graph, Status s, string name) | public static Operation Const(Tensor t, Graph graph, Status s, string name) | ||||
| { | { | ||||
| IntPtr op = IntPtr.Zero; | |||||
| Operation op = null; | |||||
| ConstHelper(t, graph, s, name, ref op); | ConstHelper(t, graph, s, name, ref op); | ||||
| return new Operation(op); | |||||
| return op; | |||||
| } | } | ||||
| public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const") | public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const") | ||||