| @@ -8,7 +8,6 @@ namespace Tensorflow | |||
| public class Buffer | |||
| { | |||
| private IntPtr _handle; | |||
| public IntPtr Handle => _handle; | |||
| private TF_Buffer buffer; | |||
| @@ -21,7 +20,8 @@ namespace Tensorflow | |||
| _handle = handle; | |||
| buffer = Marshal.PtrToStructure<TF_Buffer>(_handle); | |||
| Data = new byte[buffer.length]; | |||
| Marshal.Copy(buffer.data, Data, 0, (int)buffer.length); | |||
| if (buffer.length > 0) | |||
| Marshal.Copy(buffer.data, Data, 0, (int)buffer.length); | |||
| } | |||
| } | |||
| } | |||
| @@ -7,6 +7,13 @@ namespace Tensorflow | |||
| { | |||
| public static partial class c_api | |||
| { | |||
| /// <summary> | |||
| /// Useful for passing *out* a protobuf. | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_NewBuffer(); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_GetBuffer(TF_Buffer buffer); | |||
| } | |||
| @@ -64,7 +64,7 @@ namespace Tensorflow | |||
| if(obj is Tensor && allow_tensor) | |||
| { | |||
| if ((obj as Tensor).graph.Equals(this)) | |||
| if ((obj as Tensor).Graph.Equals(this)) | |||
| { | |||
| return obj; | |||
| } | |||
| @@ -6,28 +6,37 @@ namespace Tensorflow | |||
| { | |||
| public class Operation | |||
| { | |||
| public IntPtr Handle { get; } | |||
| private readonly IntPtr _handle; | |||
| private Graph _graph; | |||
| public Graph graph => _graph; | |||
| public IntPtr _c_op; | |||
| public Graph Graph { get; } | |||
| public int _id => _id_value; | |||
| private int _id_value; | |||
| public string name; | |||
| private Status status = new Status(); | |||
| public string name => c_api.TF_OperationName(_handle); | |||
| public string optype => c_api.TF_OperationOpType(_handle); | |||
| public string device => c_api.TF_OperationDevice(_handle); | |||
| public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | |||
| public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||
| public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); | |||
| public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||
| public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||
| public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||
| public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||
| private Tensor[] _outputs; | |||
| public Tensor[] outputs => _outputs; | |||
| public Tensor[] inputs; | |||
| public Operation(IntPtr handle) | |||
| { | |||
| Handle = handle; | |||
| _handle = handle; | |||
| } | |||
| public Operation(Graph g, string opType, string oper_name) | |||
| { | |||
| _graph = g; | |||
| var status = new Status(); | |||
| Graph = g; | |||
| var desc = c_api.TF_NewOperation(g, opType, oper_name); | |||
| c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); | |||
| @@ -36,19 +45,18 @@ namespace Tensorflow | |||
| public Operation(NodeDef node_def, Graph g, List<Tensor> inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||
| { | |||
| _graph = g; | |||
| Graph = g; | |||
| _id_value = _graph._next_id(); | |||
| _c_op = ops._create_c_op(g, node_def, inputs); | |||
| var num_outputs = c_api.TF_OperationNumOutputs(_c_op); | |||
| _id_value = Graph._next_id(); | |||
| _handle = ops._create_c_op(g, node_def, inputs); | |||
| _outputs = new Tensor[num_outputs]; | |||
| for (int i = 0; i < num_outputs; i++) | |||
| _outputs = new Tensor[NumOutputs]; | |||
| for (int i = 0; i < NumOutputs; i++) | |||
| { | |||
| _outputs[i] = new Tensor(this, i, output_types[i]); | |||
| } | |||
| _graph._add_op(this); | |||
| Graph._add_op(this); | |||
| } | |||
| public object get_attr(string name) | |||
| @@ -69,5 +77,15 @@ namespace Tensorflow | |||
| return ret; | |||
| } | |||
| public static implicit operator Operation(IntPtr handle) | |||
| { | |||
| return new Operation(handle); | |||
| } | |||
| public static implicit operator IntPtr(Operation op) | |||
| { | |||
| return op._handle; | |||
| } | |||
| } | |||
| } | |||
| @@ -8,6 +8,12 @@ namespace Tensorflow | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| public struct TF_Output | |||
| { | |||
| public TF_Output(IntPtr oper, int index) | |||
| { | |||
| this.oper = oper; | |||
| this.index = index; | |||
| } | |||
| public IntPtr oper; | |||
| public int index; | |||
| } | |||
| @@ -12,7 +12,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static unsafe extern IntPtr TF_GetAllOpList(); | |||
| public static extern IntPtr TF_GetAllOpList(); | |||
| /// <summary> | |||
| /// For inputs that take a single tensor. | |||
| @@ -20,24 +20,78 @@ namespace Tensorflow | |||
| /// <param name="desc"></param> | |||
| /// <param name="input"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static unsafe extern void TF_AddInput(IntPtr desc, TF_Output input); | |||
| public static extern void TF_AddInput(IntPtr desc, TF_Output input); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static unsafe extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | |||
| public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static unsafe extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | |||
| public static extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern unsafe int TF_OperationNumOutputs(IntPtr oper); | |||
| public static extern string TF_OperationDevice(IntPtr oper); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern unsafe void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); | |||
| public static extern string TF_OperationName(IntPtr oper); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern unsafe void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); | |||
| public static extern int TF_OperationNumInputs(IntPtr oper); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern unsafe void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value); | |||
| public static extern string TF_OperationOpType(IntPtr oper); | |||
| /// <summary> | |||
| /// Get the number of control inputs to an operation. | |||
| /// </summary> | |||
| /// <param name="oper"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_OperationNumControlInputs(IntPtr oper); | |||
| /// <summary> | |||
| /// Get the number of operations that have `*oper` as a control input. | |||
| /// </summary> | |||
| /// <param name="oper"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_OperationNumControlOutputs(IntPtr oper); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_OperationNumOutputs(IntPtr oper); | |||
| /// <summary> | |||
| /// Get the number of current consumers of a specific output of an | |||
| /// operation. Note that this number can change when new operations | |||
| /// are added to the graph. | |||
| /// </summary> | |||
| /// <param name="oper_out"></param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_OperationOutputNumConsumers(TF_Output oper_out); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern TF_DataType TF_OperationOutputType(TF_Output oper_out); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern int TF_OperationOutputListLength(IntPtr oper, string arg_name, IntPtr status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); | |||
| /// <summary> | |||
| /// Set `num_dims` to -1 to represent "unknown rank". | |||
| /// </summary> | |||
| /// <param name="desc"></param> | |||
| /// <param name="attr_name"></param> | |||
| /// <param name="dims"></param> | |||
| /// <param name="num_dims"></param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_SetAttrShape(IntPtr desc, string attr_name, long[] dims, int num_dims); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value); | |||
| } | |||
| } | |||
| @@ -13,9 +13,9 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public class Tensor | |||
| { | |||
| public IntPtr Handle { get; } | |||
| private readonly IntPtr _handle; | |||
| public Graph graph => op.graph; | |||
| public Graph Graph => op.Graph; | |||
| public Operation op { get; } | |||
| public string name; | |||
| @@ -46,7 +46,7 @@ namespace Tensorflow | |||
| public Tensor(IntPtr handle) | |||
| { | |||
| Handle = handle; | |||
| _handle = handle; | |||
| dtype = c_api.TF_TensorType(handle); | |||
| rank = c_api.TF_NumDims(handle); | |||
| bytesize = c_api.TF_TensorByteSize(handle); | |||
| @@ -60,16 +60,16 @@ namespace Tensorflow | |||
| public Tensor(NDArray nd) | |||
| { | |||
| Handle = Allocate(nd); | |||
| dtype = c_api.TF_TensorType(Handle); | |||
| rank = c_api.TF_NumDims(Handle); | |||
| bytesize = c_api.TF_TensorByteSize(Handle); | |||
| buffer = c_api.TF_TensorData(Handle); | |||
| _handle = Allocate(nd); | |||
| dtype = c_api.TF_TensorType(_handle); | |||
| rank = c_api.TF_NumDims(_handle); | |||
| bytesize = c_api.TF_TensorByteSize(_handle); | |||
| buffer = c_api.TF_TensorData(_handle); | |||
| dataTypeSize = c_api.TF_DataTypeSize(dtype); | |||
| shape = new long[rank]; | |||
| for (int i = 0; i < rank; i++) | |||
| shape[i] = c_api.TF_Dim(Handle, i); | |||
| shape[i] = c_api.TF_Dim(_handle, i); | |||
| } | |||
| private IntPtr Allocate(NDArray nd) | |||
| @@ -117,7 +117,7 @@ namespace Tensorflow | |||
| public TF_Output _as_tf_output() | |||
| { | |||
| return c_api_util.tf_output(op._c_op, value_index); | |||
| return c_api_util.tf_output(op, value_index); | |||
| } | |||
| public T[] Data<T>() | |||
| @@ -162,7 +162,7 @@ namespace Tensorflow | |||
| public static implicit operator IntPtr(Tensor tensor) | |||
| { | |||
| return tensor.Handle; | |||
| return tensor._handle; | |||
| } | |||
| } | |||
| } | |||
| @@ -10,9 +10,26 @@ namespace TensorFlowNET.UnitTest | |||
| public class GraphTest | |||
| { | |||
| [TestMethod] | |||
| public void ConstructGraph() | |||
| public void Graph() | |||
| { | |||
| var g = tf.Graph(); | |||
| var s = new Status(); | |||
| var graph = tf.get_default_graph(); | |||
| // Make a placeholder operation. | |||
| var feed = c_test_util.Placeholder(graph, s); | |||
| Assert.AreEqual("feed", feed.name); | |||
| Assert.AreEqual("Placeholder", feed.optype); | |||
| //Assert.AreEqual("", feed.device); | |||
| Assert.AreEqual(1, feed.NumOutputs); | |||
| Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType); | |||
| Assert.AreEqual(1, feed.OutputListLength); | |||
| Assert.AreEqual(0, feed.NumInputs); | |||
| Assert.AreEqual(0, feed.NumConsumers); | |||
| Assert.AreEqual(0, feed.NumControlInputs); | |||
| Assert.AreEqual(0, feed.NumControlOutputs); | |||
| var attr_value = new AttrValue(); | |||
| c_test_util.GetAttrValue(feed, "dtype", attr_value, s); | |||
| } | |||
| } | |||
| } | |||
| @@ -87,7 +87,7 @@ namespace TensorFlowNET.UnitTest | |||
| // Test for a scalar. | |||
| var three = c_test_util.ScalarConst(3, graph, s); | |||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||
| var three_out_0 = new TF_Output { oper = three.Handle }; | |||
| var three_out_0 = new TF_Output { oper = three }; | |||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | |||
| Assert.AreEqual(0, num_dims); | |||
| } | |||
| @@ -3,15 +3,42 @@ using System.Collections.Generic; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| using Tensorflow; | |||
| using Buffer = Tensorflow.Buffer; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| public static class c_test_util | |||
| { | |||
| public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref IntPtr op) | |||
| public static bool GetAttrValue(Operation oper, string attr_name, AttrValue attr_value, Status s) | |||
| { | |||
| var buffer = c_api.TF_NewBuffer(); | |||
| return s.Code == TF_Code.TF_OK; | |||
| } | |||
| public static void PlaceholderHelper(Graph graph, Status s, string name, TF_DataType dtype, long[] dims, ref Operation op) | |||
| { | |||
| var desc = c_api.TF_NewOperation(graph, "Placeholder", name); | |||
| c_api.TF_SetAttrType(desc, "dtype", dtype); | |||
| if(dims != null) | |||
| { | |||
| c_api.TF_SetAttrShape(desc, "shape", dims, dims.Length); | |||
| } | |||
| op = c_api.TF_FinishOperation(desc, s); | |||
| s.Check(); | |||
| } | |||
| public static Operation Placeholder(Graph graph, Status s, string name = "feed", TF_DataType dtype = TF_DataType.TF_INT32, long[] dims = null) | |||
| { | |||
| Operation op = null; | |||
| PlaceholderHelper(graph, s, name, dtype, dims, ref op); | |||
| return op; | |||
| } | |||
| public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref Operation op) | |||
| { | |||
| var desc = c_api.TF_NewOperation(graph, "Const", name); | |||
| c_api.TF_SetAttrTensor(desc, "value", t.Handle, s); | |||
| c_api.TF_SetAttrTensor(desc, "value", t, s); | |||
| s.Check(); | |||
| c_api.TF_SetAttrType(desc, "dtype", t.dtype); | |||
| op = c_api.TF_FinishOperation(desc, s); | |||
| @@ -24,9 +51,9 @@ namespace TensorFlowNET.UnitTest | |||
| public static Operation Const(Tensor t, Graph graph, Status s, string name) | |||
| { | |||
| IntPtr op = IntPtr.Zero; | |||
| Operation op = null; | |||
| ConstHelper(t, graph, s, name, ref op); | |||
| return new Operation(op); | |||
| return op; | |||
| } | |||
| public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const") | |||