diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index bf2799cb..f4516674 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -9,21 +9,19 @@ namespace Tensorflow { private IntPtr _handle; public IntPtr Handle => _handle; - //public TF_Buffer buffer => Marshal.PtrToStructure(_handle); - public unsafe Buffer() - { - _handle = Marshal.AllocHGlobal(sizeof(TF_Buffer)); - } + private TF_Buffer buffer; - public byte[] GetBuffer() - { - var buffer = Marshal.PtrToStructure(_handle); + public byte[] Data; - var data = Marshal.AllocHGlobal(buffer.length); - //var bytes = c_api.TF_GetBuffer(buffer.data); + public int Length => (int)buffer.length; - return null; + public unsafe Buffer(IntPtr handle) + { + _handle = handle; + buffer = Marshal.PtrToStructure(_handle); + Data = new byte[buffer.length]; + Marshal.Copy(buffer.data, Data, 0, (int)buffer.length); } } } diff --git a/src/TensorFlowNET.Core/Buffers/TF_Buffer.cs b/src/TensorFlowNET.Core/Buffers/TF_Buffer.cs index 90fc98db..3c3ac91e 100644 --- a/src/TensorFlowNET.Core/Buffers/TF_Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/TF_Buffer.cs @@ -9,7 +9,7 @@ namespace Tensorflow public struct TF_Buffer { public IntPtr data; - public int length; + public ulong length; public IntPtr data_deallocator; } } diff --git a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs index 0cc081c2..451b43b6 100644 --- a/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/c_api.buffer.cs @@ -8,6 +8,6 @@ namespace Tensorflow public static partial class c_api { [DllImport(TensorFlowLibName)] - public static extern string TF_GetBuffer(IntPtr buffer); + public static extern IntPtr TF_GetBuffer(TF_Buffer buffer); } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 2109c964..5acfc117 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -15,8 +15,7 @@ namespace Tensorflow /// public class Graph { - private IntPtr _c_graph; - public IntPtr Handle => _c_graph; + private IntPtr _handle; private Dictionary _nodes_by_id; private Dictionary _nodes_by_name; private Dictionary _names_in_use; @@ -28,7 +27,7 @@ namespace Tensorflow public Graph(IntPtr graph) { - this._c_graph = graph; + _handle = graph; _nodes_by_id = new Dictionary(); _nodes_by_name = new Dictionary(); _names_in_use = new Dictionary(); @@ -171,5 +170,10 @@ namespace Tensorflow { return _nodes_by_name.Values.Select(x => x).ToArray(); } + + public static implicit operator IntPtr(Graph graph) + { + return graph._handle; + } } } diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 337fdea9..21900f15 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -10,6 +10,39 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); + /// + /// Returns the shape of the Tensor referenced by `output` in `graph` + /// into `dims`. `dims` must be an array large enough to hold `num_dims` + /// entries (e.g., the return value of TF_GraphGetTensorNumDims). + /// + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); + + /// + /// Sets the shape of the Tensor referenced by `output` in `graph` to + /// the shape described by `dims` and `num_dims`. + /// + [DllImport(TensorFlowLibName)] + public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); + + /// + /// Returns the number of dimensions of the Tensor referenced by `output` + /// in `graph`. + /// + /// If the number of dimensions in the shape is unknown, returns -1. + /// + /// + /// + /// + /// + [DllImport(TensorFlowLibName)] + public static extern int TF_GraphGetTensorNumDims(IntPtr graph, TF_Output output, IntPtr status); + [DllImport(TensorFlowLibName)] public static unsafe extern IntPtr TF_NewGraph(); } diff --git a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs index 27ce5910..1ae4e951 100644 --- a/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs +++ b/src/TensorFlowNET.Core/Operations/OpDefLibrary.cs @@ -28,9 +28,6 @@ namespace Tensorflow { var op_def = _ops[op_type_name]; - var status = new Status(); - var buffer = new Buffer(); - var g = ops.get_default_graph(); if (String.IsNullOrEmpty(name)) diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs index d7908683..fb12dae5 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.cs @@ -1,12 +1,13 @@ using System; using System.Collections.Generic; using System.Text; -using TF_DataType = Tensorflow.DataType; namespace Tensorflow { public class Operation { + public IntPtr Handle { get; } + private Graph _graph; public Graph graph => _graph; public IntPtr _c_op; @@ -17,15 +18,20 @@ namespace Tensorflow public Tensor[] outputs => _outputs; public Tensor[] inputs; + public Operation(IntPtr handle) + { + Handle = handle; + } + public Operation(Graph g, string opType, string oper_name) { _graph = g; var status = new Status(); - var desc = c_api.TF_NewOperation(g.Handle, opType, oper_name); + var desc = c_api.TF_NewOperation(g, opType, oper_name); c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); - c_api.TF_FinishOperation(desc, status.Handle); + c_api.TF_FinishOperation(desc, status); } public Operation(NodeDef node_def, Graph g, List inputs = null, TF_DataType[] output_types = null, object control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) diff --git a/src/TensorFlowNET.Core/Graphs/TF_Input.cs b/src/TensorFlowNET.Core/Operations/TF_Input.cs similarity index 100% rename from src/TensorFlowNET.Core/Graphs/TF_Input.cs rename to src/TensorFlowNET.Core/Operations/TF_Input.cs diff --git a/src/TensorFlowNET.Core/Graphs/TF_Output.cs b/src/TensorFlowNET.Core/Operations/TF_Output.cs similarity index 100% rename from src/TensorFlowNET.Core/Graphs/TF_Output.cs rename to src/TensorFlowNET.Core/Operations/TF_Output.cs diff --git a/src/TensorFlowNET.Core/Operations/c_api.ops.cs b/src/TensorFlowNET.Core/Operations/c_api.ops.cs index 0422aa75..0fafaa10 100644 --- a/src/TensorFlowNET.Core/Operations/c_api.ops.cs +++ b/src/TensorFlowNET.Core/Operations/c_api.ops.cs @@ -7,30 +7,37 @@ namespace Tensorflow { public static partial class c_api { + /// + /// Get the OpList of all OpDefs defined in this address space. + /// + /// + [DllImport(TensorFlowLibName)] + public static unsafe extern IntPtr TF_GetAllOpList(); + /// /// For inputs that take a single tensor. /// /// /// [DllImport(TensorFlowLibName)] - public static unsafe extern void TF_AddInput(TF_OperationDescription desc, TF_Output input); + public static unsafe extern void TF_AddInput(IntPtr desc, TF_Output input); [DllImport(TensorFlowLibName)] - public static unsafe extern IntPtr TF_FinishOperation(TF_OperationDescription desc, IntPtr status); + public static unsafe extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); [DllImport(TensorFlowLibName)] - public static unsafe extern TF_OperationDescription TF_NewOperation(IntPtr graph, string opType, string oper_name); + public static unsafe extern IntPtr TF_NewOperation(IntPtr graph, string opType, string oper_name); [DllImport(TensorFlowLibName)] public static extern unsafe int TF_OperationNumOutputs(IntPtr oper); [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SetAttrValueProto(TF_OperationDescription desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); + public static extern unsafe void TF_SetAttrValueProto(IntPtr desc, string attr_name, IntPtr proto, UIntPtr proto_len, IntPtr status); [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SetAttrTensor(TF_OperationDescription desc, string attr_name, IntPtr value, IntPtr status); + public static extern unsafe void TF_SetAttrTensor(IntPtr desc, string attr_name, IntPtr value, IntPtr status); [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SetAttrType(TF_OperationDescription desc, string attr_name, TF_DataType value); + public static extern unsafe void TF_SetAttrType(IntPtr desc, string attr_name, TF_DataType value); } } diff --git a/src/TensorFlowNET.Core/Operations/ops.cs b/src/TensorFlowNET.Core/Operations/ops.cs index 7b6dfcf7..28449fc2 100644 --- a/src/TensorFlowNET.Core/Operations/ops.cs +++ b/src/TensorFlowNET.Core/Operations/ops.cs @@ -24,7 +24,7 @@ namespace Tensorflow public static unsafe IntPtr _create_c_op(Graph graph, NodeDef node_def, List inputs) { - var op_desc = c_api.TF_NewOperation(graph.Handle, node_def.Op, node_def.Name); + var op_desc = c_api.TF_NewOperation(graph, node_def.Op, node_def.Name); // Add inputs if(inputs != null) @@ -45,12 +45,12 @@ namespace Tensorflow var bytes = attr.Value.ToByteArray(); var proto = Marshal.AllocHGlobal(bytes.Length); Marshal.Copy(bytes, 0, proto, bytes.Length); - c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status.Handle); + c_api.TF_SetAttrValueProto(op_desc, attr.Key, proto, proto_len: (UIntPtr)bytes.Length, status: status); if(status.Code != TF_Code.TF_OK) throw new Exception(status.Message); } - var c_op = c_api.TF_FinishOperation(op_desc, status.Handle); + var c_op = c_api.TF_FinishOperation(op_desc, status); if (status.Code != TF_Code.TF_OK) throw new Exception(status.Message); diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 46e388a8..3c8a8112 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -30,7 +30,7 @@ namespace Tensorflow _target = UTF8Encoding.UTF8.GetBytes(target); var opts = c_api.TF_NewSessionOptions(); var status = new Status(); - _session = c_api.TF_NewSession(_graph.Handle, opts, status.Handle); + _session = c_api.TF_NewSession(_graph, opts, status); c_api.TF_DeleteSessionOptions(opts); } @@ -40,30 +40,30 @@ namespace Tensorflow } - public virtual object run(Tensor fetches, FeedDict feed_dict = null) + public virtual object run(Tensor fetches, Dictionary feed_dict = null) { var result = _run(fetches, feed_dict); return result; } - private unsafe object _run(Tensor fetches, FeedDict feed_dict = null) + private unsafe object _run(Tensor fetches, Dictionary feed_dict = null) { - var feed_dict_tensor = new FeedDict(); + var feed_dict_tensor = new Dictionary(); if (feed_dict != null) { NDArray np_val = null; - foreach (FeedValue feed in feed_dict) + foreach (var feed in feed_dict) { - switch (feed.feed_val) + switch (feed.Value) { case float value: np_val = np.asarray(value); break; } - feed_dict_tensor[feed.feed] = np_val; + feed_dict_tensor[feed.Key] = np_val; } } @@ -85,9 +85,9 @@ namespace Tensorflow return fetch_handler.build_results(null, results); } - private object[] _do_run(List fetch_list, FeedDict feed_dict) + private object[] _do_run(List fetch_list, Dictionary feed_dict) { - var feeds = feed_dict.items().Select(x => new KeyValuePair(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); + var feeds = feed_dict.Select(x => new KeyValuePair(x.Key._as_tf_output(), new Tensor(x.Value as NDArray))).ToArray(); var fetches = fetch_list.Select(x => x._as_tf_output()).ToArray(); return _call_tf_sessionrun(feeds, fetches); @@ -113,7 +113,7 @@ namespace Tensorflow target_opers: new IntPtr[] { }, ntargets: 0, run_metadata: IntPtr.Zero, - status: status.Handle); + status: status); var result = output_values.Select(x => c_api.TF_TensorData(x)) .Select(x => (object)*(float*)x) diff --git a/src/TensorFlowNET.Core/Sessions/FeedDict.cs b/src/TensorFlowNET.Core/Sessions/FeedDict.cs deleted file mode 100644 index 7d36e899..00000000 --- a/src/TensorFlowNET.Core/Sessions/FeedDict.cs +++ /dev/null @@ -1,59 +0,0 @@ -using System; -using System.Collections; -using System.Collections.Generic; -using System.Text; - -namespace Tensorflow -{ - public class FeedDict : IEnumerable - { - private Dictionary feed_dict; - - public FeedDict() - { - feed_dict = new Dictionary(); - } - - public object this[Tensor feed] - { - get - { - return feed_dict[feed]; - } - - set - { - feed_dict[feed] = value; - } - } - - public FeedDict Add(Tensor feed, object value) - { - feed_dict.Add(feed, value); - return this; - } - - public IEnumerator GetEnumerator() - { - foreach (KeyValuePair feed in feed_dict) - { - yield return new FeedValue - { - feed = feed.Key, - feed_val = feed.Value - }; - } - } - - public Dictionary items() - { - return feed_dict; - } - } - - public struct FeedValue - { - public Tensor feed { get; set; } - public object feed_val { get; set; } - } -} diff --git a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs index 94c8b2ed..347a1293 100644 --- a/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs +++ b/src/TensorFlowNET.Core/Sessions/_FetchHandler.cs @@ -15,7 +15,7 @@ namespace Tensorflow private List _final_fetches = new List(); private List _targets = new List(); - public _FetchHandler(Graph graph, Tensor fetches, FeedDict feeds = null, object feed_handles = null) + public _FetchHandler(Graph graph, Tensor fetches, Dictionary feeds = null, object feed_handles = null) { _fetch_mapper = new _FetchMapper().for_fetch(fetches); foreach(var fetch in _fetch_mapper.unique_fetches()) diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index dc369386..84a15aec 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -4,10 +4,13 @@ using System.Text; namespace Tensorflow { - public class Status : IDisposable + /// + /// TF_Status holds error information. It either has an OK code, or + /// else an error code with an associated error message. + /// + public class Status { private readonly IntPtr _handle; - public IntPtr Handle => _handle; /// /// Error message @@ -29,6 +32,23 @@ namespace Tensorflow c_api.TF_SetStatus(_handle, code, msg); } + /// + /// Check status + /// Throw exception with error message if code != TF_OK + /// + public void Check() + { + if(Code != TF_Code.TF_OK) + { + throw new Exception(Message); + } + } + + public static implicit operator IntPtr(Status status) + { + return status._handle; + } + public void Dispose() { c_api.TF_DeleteStatus(_handle); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 28a878f3..6e27b305 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -13,6 +13,8 @@ namespace Tensorflow /// public class Tensor { + public IntPtr Handle { get; } + public Graph graph => op.graph; public Operation op { get; } @@ -21,7 +23,6 @@ namespace Tensorflow public int value_index { get; } public TF_DataType dtype { get; } - public IntPtr handle { get; } public ulong bytesize { get; } public ulong dataTypeSize { get;} public ulong size => bytesize / dataTypeSize; @@ -45,7 +46,7 @@ namespace Tensorflow public Tensor(IntPtr handle) { - this.handle = handle; + Handle = handle; dtype = c_api.TF_TensorType(handle); rank = c_api.TF_NumDims(handle); bytesize = c_api.TF_TensorByteSize(handle); @@ -59,33 +60,52 @@ namespace Tensorflow public Tensor(NDArray nd) { - var data = Marshal.AllocHGlobal(sizeof(float) * nd.size); - Marshal.Copy(nd.Data(), 0, data, nd.size); - var dataType = ToTFDataType(nd.dtype); + Handle = Allocate(nd); + dtype = c_api.TF_TensorType(Handle); + rank = c_api.TF_NumDims(Handle); + bytesize = c_api.TF_TensorByteSize(Handle); + buffer = c_api.TF_TensorData(Handle); + dataTypeSize = c_api.TF_DataTypeSize(dtype); + + shape = new long[rank]; + for (int i = 0; i < rank; i++) + shape[i] = c_api.TF_Dim(Handle, i); + } + + private IntPtr Allocate(NDArray nd) + { + var dotHandle = Marshal.AllocHGlobal(nd.dtypesize * nd.size); + + switch (nd.dtype.Name) + { + case "Int32": + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); + break; + case "Single": + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); + break; + case "Double": + Marshal.Copy(nd.Data(), 0, dotHandle, nd.size); + break; + default: + throw new NotImplementedException("Marshal.Copy failed."); + } - var handle = c_api.TF_NewTensor(dataType, + var dataType = ToTFDataType(nd.dtype); + var tfHandle = c_api.TF_NewTensor(dataType, nd.shape.Select(x => (long)x).ToArray(), // shape nd.ndim, - data, - (UIntPtr)(nd.size * sizeof(float)), + dotHandle, + (UIntPtr)(nd.size * nd.dtypesize), (IntPtr values, IntPtr len, ref bool closure) => { // Free the original buffer and set flag - Marshal.FreeHGlobal(data); + Marshal.FreeHGlobal(dotHandle); closure = true; }, ref deallocator_called); - this.handle = handle; - dtype = c_api.TF_TensorType(handle); - rank = c_api.TF_NumDims(handle); - bytesize = c_api.TF_TensorByteSize(handle); - buffer = c_api.TF_TensorData(handle); - dataTypeSize = c_api.TF_DataTypeSize(dtype); - - shape = new long[rank]; - for (int i = 0; i < rank; i++) - shape[i] = c_api.TF_Dim(handle, i); + return tfHandle; } public Tensor(Operation op, int value_index, TF_DataType dtype) @@ -129,11 +149,20 @@ namespace Tensorflow { switch (type.Name) { + case "Int32": + return TF_DataType.TF_INT32; case "Single": return TF_DataType.TF_FLOAT; + case "Double": + return TF_DataType.TF_DOUBLE; } return TF_DataType.DtInvalid; } + + public static implicit operator IntPtr(Tensor tensor) + { + return tensor.Handle; + } } } diff --git a/src/TensorFlowNET.Core/c_api.cs b/src/TensorFlowNET.Core/c_api.cs index bb49a049..94e530c0 100644 --- a/src/TensorFlowNET.Core/c_api.cs +++ b/src/TensorFlowNET.Core/c_api.cs @@ -10,12 +10,22 @@ namespace Tensorflow /// /// The API leans towards simplicity and uniformity instead of convenience /// since most usage will be by language specific wrappers. + /// + /// The params type mapping between .net and c_api + /// TF_XX** => ref IntPtr (TF_Operation** op) => (ref IntPtr op) + /// TF_XX* => IntPtr (TF_Graph* graph) => (IntPtr graph) + /// struct => struct (TF_Output output) => (TF_Output output) + /// const char* => string + /// int32_t => int + /// int64_t* => long[] + /// size_t* => unlong[] + /// void* => IntPtr /// public static partial class c_api { public const string TensorFlowLibName = "tensorflow"; - public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocatorData); + public delegate void Deallocator(IntPtr data, IntPtr size, ref bool deallocator); [DllImport(TensorFlowLibName)] public static unsafe extern IntPtr TF_Version(); diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 01749733..7c3907dc 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -3,12 +3,21 @@ using System; using System.Collections.Generic; using System.Text; using Tensorflow; +using Buffer = Tensorflow.Buffer; namespace TensorFlowNET.UnitTest { [TestClass] public class OperationsTest { + [TestMethod] + public void GetAllOpList() + { + var handle = c_api.TF_GetAllOpList(); + var buffer = new Buffer(handle); + Assert.IsTrue(buffer.Length == buffer.Data.Length); + } + [TestMethod] public void addInPlaceholder() { @@ -18,9 +27,9 @@ namespace TensorFlowNET.UnitTest using(var sess = tf.Session()) { - var feed_dict = new FeedDict() - .Add(a, 3.0f) - .Add(b, 2.0f); + var feed_dict = new Dictionary(); + feed_dict.Add(a, 3.0f); + feed_dict.Add(b, 2.0f); var o = sess.run(c, feed_dict); } diff --git a/test/TensorFlowNET.UnitTest/TensorTest.cs b/test/TensorFlowNET.UnitTest/TensorTest.cs index f111e07c..bb7c8f27 100644 --- a/test/TensorFlowNET.UnitTest/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/TensorTest.cs @@ -13,7 +13,7 @@ namespace TensorFlowNET.UnitTest public class TensorTest { [TestMethod] - public unsafe void NewTensor() + public void NewTensor() { var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); @@ -27,5 +27,69 @@ namespace TensorFlowNET.UnitTest Assert.AreEqual(tensor.bytesize, (uint)nd.size * sizeof(float)); Assert.IsTrue(Enumerable.SequenceEqual(nd.Data(), array)); } + + /// + /// Port from tensorflow\c\c_api_test.cc + /// + [TestMethod] + public void SetShape() + { + var s = new Status(); + var graph = tf.get_default_graph(); + + var desc = c_api.TF_NewOperation(graph, "Placeholder", ""); + c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_FLOAT); + //if (!dims.empty()) + { + //TF_SetAttrShape(desc, "shape", dims.data(), dims.size()); + } + var op = c_api.TF_FinishOperation(desc, s); + + Assert.IsTrue(s.Code == TF_Code.TF_OK); + Assert.IsNotNull(op); + + // Fetch the shape, it should be completely unknown. + var feed_out_0 = new TF_Output { oper = op, index = 0 }; + int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); + + Assert.IsTrue(s.Code == TF_Code.TF_OK); + Assert.AreEqual(-1, num_dims); + + // Set the shape to be unknown, expect no change. + c_api.TF_GraphSetTensorShape(graph, feed_out_0, new int[0], -1, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); + Assert.AreEqual(-1, num_dims); + + // Set the shape to be 2 x Unknown + var dims = new int[] { 2, -1 }; + c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); + Assert.AreEqual(2, num_dims); + + // Get the dimension vector appropriately. + var returned_dims = new int[dims.Length]; + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); + + // Set to a new valid shape: [2, 3] + dims[1] = 3; + c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); + //Assert.IsTrue(s.Code == TF_Code.TF_OK); + + // Fetch and see that the new value is returned. + c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); + //Assert.IsTrue(s.Code == TF_Code.TF_OK); + //Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); + + // Test for a scalar. + var three = c_test_util.ScalarConst(3, graph, s); + Assert.IsTrue(s.Code == TF_Code.TF_OK); + var three_out_0 = new TF_Output { oper = three.Handle }; + num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); + Assert.AreEqual(0, num_dims); + } } } diff --git a/test/TensorFlowNET.UnitTest/c_test_util.cs b/test/TensorFlowNET.UnitTest/c_test_util.cs new file mode 100644 index 00000000..add1913c --- /dev/null +++ b/test/TensorFlowNET.UnitTest/c_test_util.cs @@ -0,0 +1,37 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; +using Tensorflow; + +namespace TensorFlowNET.UnitTest +{ + public static class c_test_util + { + public static void ConstHelper(Tensor t, Graph graph, Status s, string name, ref IntPtr op) + { + var desc = c_api.TF_NewOperation(graph, "Const", name); + c_api.TF_SetAttrTensor(desc, "value", t.Handle, s); + s.Check(); + c_api.TF_SetAttrType(desc, "dtype", t.dtype); + op = c_api.TF_FinishOperation(desc, s); + s.Check(); + if(op == null) + { + throw new Exception("c_api.TF_FinishOperation failed."); + } + } + + public static Operation Const(Tensor t, Graph graph, Status s, string name) + { + IntPtr op = IntPtr.Zero; + ConstHelper(t, graph, s, name, ref op); + return new Operation(op); + } + + public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const") + { + return Const(new Tensor(v), graph, s, name); + } + } +}