| @@ -4,7 +4,7 @@ TensorFlow.NET provides .NET Standard binding for [TensorFlow](https://www.tenso | |||||
| [](https://gitter.im/sci-sharp/community) | [](https://gitter.im/sci-sharp/community) | ||||
|  |  | ||||
| TensorFlow.NET is a member project of SciSharp stack. | |||||
| TensorFlow.NET is a member project of [SciSharp](https://github.com/SciSharp) stack. | |||||
|  |  | ||||
| @@ -45,3 +45,5 @@ using(var sess = tf.Session()) | |||||
| var o = sess.run(c, feed_dict); | var o = sess.run(c, feed_dict); | ||||
| } | } | ||||
| ``` | ``` | ||||
| Star me or raise issue on [Github](https://github.com/SciSharp/TensorFlow.NET) feel free. | |||||
| @@ -13,7 +13,7 @@ namespace Tensorflow | |||||
| /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | /// then create a TensorFlow session to run parts of the graph across a set of local and remote devices. | ||||
| /// https://www.tensorflow.org/guide/graphs | /// https://www.tensorflow.org/guide/graphs | ||||
| /// </summary> | /// </summary> | ||||
| public class Graph | |||||
| public class Graph : IDisposable | |||||
| { | { | ||||
| private IntPtr _handle; | private IntPtr _handle; | ||||
| private Dictionary<int, Operation> _nodes_by_id; | private Dictionary<int, Operation> _nodes_by_id; | ||||
| @@ -25,6 +25,11 @@ namespace Tensorflow | |||||
| private string _name_stack; | private string _name_stack; | ||||
| public Graph() | |||||
| { | |||||
| _handle = c_api.TF_NewGraph(); | |||||
| } | |||||
| public Graph(IntPtr graph) | public Graph(IntPtr graph) | ||||
| { | { | ||||
| _handle = graph; | _handle = graph; | ||||
| @@ -171,6 +176,11 @@ namespace Tensorflow | |||||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | return _nodes_by_name.Values.Select(x => x).ToArray(); | ||||
| } | } | ||||
| public void Dispose() | |||||
| { | |||||
| c_api.TF_DeleteGraph(_handle); | |||||
| } | |||||
| public static implicit operator IntPtr(Graph graph) | public static implicit operator IntPtr(Graph graph) | ||||
| { | { | ||||
| return graph._handle; | return graph._handle; | ||||
| @@ -7,6 +7,14 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static partial class c_api | public static partial class c_api | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Destroy an options object. Graph will be deleted once no more | |||||
| /// TFSession's are referencing it. | |||||
| /// </summary> | |||||
| /// <param name="graph"></param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_DeleteGraph(IntPtr graph); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); | ||||
| @@ -21,14 +29,14 @@ namespace Tensorflow | |||||
| /// <param name="num_dims"></param> | /// <param name="num_dims"></param> | ||||
| /// <param name="status"></param> | /// <param name="status"></param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); | |||||
| public static extern void TF_GraphGetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Sets the shape of the Tensor referenced by `output` in `graph` to | /// Sets the shape of the Tensor referenced by `output` in `graph` to | ||||
| /// the shape described by `dims` and `num_dims`. | /// the shape described by `dims` and `num_dims`. | ||||
| /// </summary> | /// </summary> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, int[] dims, int num_dims, IntPtr status); | |||||
| public static extern void TF_GraphSetTensorShape(IntPtr graph, TF_Output output, long[] dims, int num_dims, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the number of dimensions of the Tensor referenced by `output` | /// Returns the number of dimensions of the Tensor referenced by `output` | ||||
| @@ -14,16 +14,16 @@ namespace Tensorflow | |||||
| private Status status = new Status(); | private Status status = new Status(); | ||||
| public string name => c_api.TF_OperationName(_handle); | |||||
| public string optype => c_api.TF_OperationOpType(_handle); | |||||
| public string device => c_api.TF_OperationDevice(_handle); | |||||
| public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | |||||
| public TF_DataType OutputType => c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||||
| public int OutputListLength => c_api.TF_OperationOutputListLength(_handle, "output", status); | |||||
| public int NumInputs => c_api.TF_OperationNumInputs(_handle); | |||||
| public int NumConsumers => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||||
| public int NumControlInputs => c_api.TF_OperationNumControlInputs(_handle); | |||||
| public int NumControlOutputs => c_api.TF_OperationNumControlOutputs(_handle); | |||||
| public string name { get; } | |||||
| public string optype { get; } | |||||
| public string device { get; } | |||||
| public int NumOutputs { get; } | |||||
| public TF_DataType OutputType { get; } | |||||
| public int OutputListLength { get; } | |||||
| public int NumInputs { get; } | |||||
| public int NumConsumers { get; } | |||||
| public int NumControlInputs { get; } | |||||
| public int NumControlOutputs { get; } | |||||
| private Tensor[] _outputs; | private Tensor[] _outputs; | ||||
| public Tensor[] outputs => _outputs; | public Tensor[] outputs => _outputs; | ||||
| @@ -31,7 +31,21 @@ namespace Tensorflow | |||||
| public Operation(IntPtr handle) | public Operation(IntPtr handle) | ||||
| { | { | ||||
| if (handle == IntPtr.Zero) | |||||
| return; | |||||
| _handle = handle; | _handle = handle; | ||||
| name = c_api.TF_OperationName(_handle); | |||||
| optype = c_api.TF_OperationOpType(_handle); | |||||
| device = "";// c_api.TF_OperationDevice(_handle); | |||||
| NumOutputs = c_api.TF_OperationNumOutputs(_handle); | |||||
| OutputType = c_api.TF_OperationOutputType(new TF_Output(_handle, 0)); | |||||
| OutputListLength = c_api.TF_OperationOutputListLength(_handle, "output", status); | |||||
| NumInputs = c_api.TF_OperationNumInputs(_handle); | |||||
| NumConsumers = c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, 0)); | |||||
| NumControlInputs = c_api.TF_OperationNumControlInputs(_handle); | |||||
| NumControlOutputs = c_api.TF_OperationNumControlOutputs(_handle); | |||||
| } | } | ||||
| public Operation(Graph g, string opType, string oper_name) | public Operation(Graph g, string opType, string oper_name) | ||||
| @@ -14,7 +14,7 @@ namespace Tensorflow | |||||
| this.index = index; | this.index = index; | ||||
| } | } | ||||
| public IntPtr oper; | |||||
| public unsafe IntPtr oper; | |||||
| public int index; | public int index; | ||||
| } | } | ||||
| } | } | ||||
| @@ -22,6 +22,15 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_AddInput(IntPtr desc, TF_Output input); | public static extern void TF_AddInput(IntPtr desc, TF_Output input); | ||||
| /// <summary> | |||||
| /// For inputs that take a list of tensors. | |||||
| /// inputs must point to TF_Output[num_inputs]. | |||||
| /// </summary> | |||||
| /// <param name="desc"></param> | |||||
| /// <param name="inputs"></param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TF_AddInputList(IntPtr desc, TF_Output[] inputs, int num_inputs); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | public static extern IntPtr TF_FinishOperation(IntPtr desc, IntPtr status); | ||||
| @@ -11,7 +11,7 @@ namespace Tensorflow | |||||
| /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | ||||
| /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | ||||
| /// </summary> | /// </summary> | ||||
| public class Tensor | |||||
| public class Tensor : IDisposable | |||||
| { | { | ||||
| private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
| @@ -38,6 +38,7 @@ namespace Tensorflow | |||||
| /// n n-Tensor (you get the idea) | /// n n-Tensor (you get the idea) | ||||
| /// </summary> | /// </summary> | ||||
| public int rank; | public int rank; | ||||
| public int NDims => rank; | |||||
| /// <summary> | /// <summary> | ||||
| /// if original buffer is free. | /// if original buffer is free. | ||||
| @@ -96,7 +97,7 @@ namespace Tensorflow | |||||
| nd.shape.Select(x => (long)x).ToArray(), // shape | nd.shape.Select(x => (long)x).ToArray(), // shape | ||||
| nd.ndim, | nd.ndim, | ||||
| dotHandle, | dotHandle, | ||||
| (UIntPtr)(nd.size * nd.dtypesize), | |||||
| (ulong)(nd.size * nd.dtypesize), | |||||
| (IntPtr values, IntPtr len, ref bool closure) => | (IntPtr values, IntPtr len, ref bool closure) => | ||||
| { | { | ||||
| // Free the original buffer and set flag | // Free the original buffer and set flag | ||||
| @@ -160,9 +161,19 @@ namespace Tensorflow | |||||
| return TF_DataType.DtInvalid; | return TF_DataType.DtInvalid; | ||||
| } | } | ||||
| public void Dispose() | |||||
| { | |||||
| c_api.TF_DeleteTensor(_handle); | |||||
| } | |||||
| public static implicit operator IntPtr(Tensor tensor) | public static implicit operator IntPtr(Tensor tensor) | ||||
| { | { | ||||
| return tensor._handle; | return tensor._handle; | ||||
| } | } | ||||
| public static implicit operator Tensor(IntPtr handle) | |||||
| { | |||||
| return new Tensor(handle); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -7,20 +7,23 @@ namespace Tensorflow | |||||
| { | { | ||||
| public static partial class c_api | public static partial class c_api | ||||
| { | { | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, long[] dims, int num_dims, ulong len); | |||||
| /// <summary> | /// <summary> | ||||
| /// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value. | /// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="dt"></param> | /// <param name="dt"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern ulong TF_DataTypeSize(TF_DataType dt); | |||||
| public static extern ulong TF_DataTypeSize(TF_DataType dt); | |||||
| /// <summary> | /// <summary> | ||||
| /// Destroy a tensor. | /// Destroy a tensor. | ||||
| /// </summary> | /// </summary> | ||||
| /// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static unsafe extern void TF_DeleteTensor(IntPtr tensor); | |||||
| public static extern void TF_DeleteTensor(IntPtr tensor); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return the length of the tensor in the "dim_index" dimension. | /// Return the length of the tensor in the "dim_index" dimension. | ||||
| @@ -30,7 +33,7 @@ namespace Tensorflow | |||||
| /// <param name="dim_index"></param> | /// <param name="dim_index"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe long TF_Dim(IntPtr tensor, int dim_index); | |||||
| public static extern long TF_Dim(IntPtr tensor, int dim_index); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return a new tensor that holds the bytes data[0,len-1] | /// Return a new tensor that holds the bytes data[0,len-1] | ||||
| @@ -44,7 +47,7 @@ namespace Tensorflow | |||||
| /// <param name="deallocator_arg"></param> | /// <param name="deallocator_arg"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, UIntPtr len, Deallocator deallocator, ref bool deallocator_arg); | |||||
| public static extern IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len, Deallocator deallocator, ref bool deallocator_arg); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return the number of dimensions that the tensor has. | /// Return the number of dimensions that the tensor has. | ||||
| @@ -52,7 +55,7 @@ namespace Tensorflow | |||||
| /// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe int TF_NumDims(IntPtr tensor); | |||||
| public static extern int TF_NumDims(IntPtr tensor); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return the size of the underlying data in bytes. | /// Return the size of the underlying data in bytes. | ||||
| @@ -60,7 +63,7 @@ namespace Tensorflow | |||||
| /// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe ulong TF_TensorByteSize(IntPtr tensor); | |||||
| public static extern ulong TF_TensorByteSize(IntPtr tensor); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return a pointer to the underlying data buffer. | /// Return a pointer to the underlying data buffer. | ||||
| @@ -68,7 +71,7 @@ namespace Tensorflow | |||||
| /// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe IntPtr TF_TensorData(IntPtr tensor); | |||||
| public static extern IntPtr TF_TensorData(IntPtr tensor); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return the type of a tensor element. | /// Return the type of a tensor element. | ||||
| @@ -76,6 +79,6 @@ namespace Tensorflow | |||||
| /// <param name="tensor"></param> | /// <param name="tensor"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe TF_DataType TF_TensorType(IntPtr tensor); | |||||
| public static extern TF_DataType TF_TensorType(IntPtr tensor); | |||||
| } | } | ||||
| } | } | ||||
| @@ -7,6 +7,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// C API for TensorFlow. | /// C API for TensorFlow. | ||||
| /// Port from tensorflow\c\c_api.h | |||||
| /// | /// | ||||
| /// The API leans towards simplicity and uniformity instead of convenience | /// The API leans towards simplicity and uniformity instead of convenience | ||||
| /// since most usage will be by language specific wrappers. | /// since most usage will be by language specific wrappers. | ||||
| @@ -9,17 +9,21 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestClass] | [TestClass] | ||||
| public class GraphTest | public class GraphTest | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Port from c_api_test.cc | |||||
| /// `TEST(CAPI, Graph)` | |||||
| /// </summary> | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Graph() | |||||
| public void c_api_Graph() | |||||
| { | { | ||||
| var s = new Status(); | var s = new Status(); | ||||
| var graph = tf.get_default_graph(); | |||||
| var graph = new Graph(); | |||||
| // Make a placeholder operation. | // Make a placeholder operation. | ||||
| var feed = c_test_util.Placeholder(graph, s); | var feed = c_test_util.Placeholder(graph, s); | ||||
| Assert.AreEqual("feed", feed.name); | Assert.AreEqual("feed", feed.name); | ||||
| Assert.AreEqual("Placeholder", feed.optype); | Assert.AreEqual("Placeholder", feed.optype); | ||||
| //Assert.AreEqual("", feed.device); | |||||
| Assert.AreEqual("", feed.device); | |||||
| Assert.AreEqual(1, feed.NumOutputs); | Assert.AreEqual(1, feed.NumOutputs); | ||||
| Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType); | Assert.AreEqual(TF_DataType.TF_INT32, feed.OutputType); | ||||
| Assert.AreEqual(1, feed.OutputListLength); | Assert.AreEqual(1, feed.OutputListLength); | ||||
| @@ -30,6 +34,19 @@ namespace TensorFlowNET.UnitTest | |||||
| AttrValue attr_value = null; | AttrValue attr_value = null; | ||||
| c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s); | c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s); | ||||
| Assert.AreEqual(attr_value.Type, DataType.DtInt32); | |||||
| // Test not found errors in TF_Operation*() query functions. | |||||
| // Assert.AreEqual(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s)); | |||||
| // Assert.AreEqual(TF_Code.TF_INVALID_ARGUMENT, s.Code); | |||||
| // Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s)); | |||||
| // Assert.AreEqual("Operation 'feed' has no attr named 'missing'.", s.Message); | |||||
| // Make a constant oper with the scalar "3". | |||||
| var three = c_test_util.ScalarConst(3, graph, s); | |||||
| // Add oper. | |||||
| var add = c_test_util.Add(feed, three, graph, s); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -10,8 +10,12 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestClass] | [TestClass] | ||||
| public class OperationsTest | public class OperationsTest | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Port from tensorflow\c\c_api_test.cc | |||||
| /// `TEST(CAPI, GetAllOpList)` | |||||
| /// </summary> | |||||
| [TestMethod] | [TestMethod] | ||||
| public void GetAllOpList() | |||||
| public void c_api_GetAllOpList() | |||||
| { | { | ||||
| var handle = c_api.TF_GetAllOpList(); | var handle = c_api.TF_GetAllOpList(); | ||||
| var buffer = new Buffer(handle); | var buffer = new Buffer(handle); | ||||
| @@ -12,8 +12,29 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestClass] | [TestClass] | ||||
| public class TensorTest | public class TensorTest | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Port from c_api_test.cc | |||||
| /// `TEST(CAPI, AllocateTensor)` | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public void c_api_AllocateTensor() | |||||
| { | |||||
| ulong num_bytes = 6 * sizeof(float); | |||||
| long[] dims = { 2, 3 }; | |||||
| Tensor t = c_api.TF_AllocateTensor(TF_DataType.TF_FLOAT, dims, 2, num_bytes); | |||||
| Assert.AreEqual(TF_DataType.TF_FLOAT, t.dtype); | |||||
| Assert.AreEqual(2, t.NDims); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, t.shape)); | |||||
| Assert.AreEqual(num_bytes, t.bytesize); | |||||
| t.Dispose(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Port from c_api_test.cc | |||||
| /// `TEST(CAPI, Tensor)` | |||||
| /// </summary> | |||||
| [TestMethod] | [TestMethod] | ||||
| public void NewTensor() | |||||
| public void c_api_Tensor() | |||||
| { | { | ||||
| var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | var nd = np.array(1f, 2f, 3f, 4f, 5f, 6f).reshape(2, 3); | ||||
| @@ -30,46 +51,38 @@ namespace TensorFlowNET.UnitTest | |||||
| /// <summary> | /// <summary> | ||||
| /// Port from tensorflow\c\c_api_test.cc | /// Port from tensorflow\c\c_api_test.cc | ||||
| /// `TEST(CAPI, SetShape)` | |||||
| /// </summary> | /// </summary> | ||||
| [TestMethod] | [TestMethod] | ||||
| public void SetShape() | |||||
| public void c_api_SetShape() | |||||
| { | { | ||||
| var s = new Status(); | var s = new Status(); | ||||
| var graph = tf.get_default_graph(); | |||||
| var graph = new Graph(); | |||||
| var desc = c_api.TF_NewOperation(graph, "Placeholder", ""); | |||||
| c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_FLOAT); | |||||
| //if (!dims.empty()) | |||||
| { | |||||
| //TF_SetAttrShape(desc, "shape", dims.data(), dims.size()); | |||||
| } | |||||
| var op = c_api.TF_FinishOperation(desc, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.IsNotNull(op); | |||||
| var feed = c_test_util.Placeholder(graph, s); | |||||
| var feed_out_0 = new TF_Output(feed, 0); | |||||
| // Fetch the shape, it should be completely unknown. | // Fetch the shape, it should be completely unknown. | ||||
| var feed_out_0 = new TF_Output { oper = op, index = 0 }; | |||||
| int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | int num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | ||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| Assert.AreEqual(-1, num_dims); | Assert.AreEqual(-1, num_dims); | ||||
| // Set the shape to be unknown, expect no change. | // Set the shape to be unknown, expect no change. | ||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, new int[0], -1, s); | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | ||||
| Assert.AreEqual(-1, num_dims); | Assert.AreEqual(-1, num_dims); | ||||
| // Set the shape to be 2 x Unknown | // Set the shape to be 2 x Unknown | ||||
| var dims = new int[] { 2, -1 }; | |||||
| long[] dims = { 2, -1 }; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | ||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | num_dims = c_api.TF_GraphGetTensorNumDims(graph, feed_out_0, s); | ||||
| Assert.AreEqual(2, num_dims); | Assert.AreEqual(2, num_dims); | ||||
| // Get the dimension vector appropriately. | // Get the dimension vector appropriately. | ||||
| var returned_dims = new int[dims.Length]; | |||||
| var returned_dims = new long[dims.Length]; | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | ||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | ||||
| @@ -77,19 +90,57 @@ namespace TensorFlowNET.UnitTest | |||||
| // Set to a new valid shape: [2, 3] | // Set to a new valid shape: [2, 3] | ||||
| dims[1] = 3; | dims[1] = 3; | ||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, dims.Length, s); | ||||
| //Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| // Fetch and see that the new value is returned. | // Fetch and see that the new value is returned. | ||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | ||||
| //Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| //Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.IsTrue(Enumerable.SequenceEqual(dims, returned_dims)); | |||||
| // Try to set 'unknown' with unknown rank on the shape and see that | |||||
| // it doesn't change. | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, null, -1, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.AreEqual(2, num_dims); | |||||
| Assert.AreEqual(2, returned_dims[0]); | |||||
| Assert.AreEqual(3, returned_dims[1]); | |||||
| // Try to set 'unknown' with same rank on the shape and see that | |||||
| // it doesn't change. | |||||
| dims[0] = -1; | |||||
| dims[1] = -1; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, num_dims, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.AreEqual(2, num_dims); | |||||
| Assert.AreEqual(2, returned_dims[0]); | |||||
| Assert.AreEqual(3, returned_dims[1]); | |||||
| // Try to fetch a shape with the wrong num_dims | |||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, returned_dims, 5, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||||
| // Try to set an invalid shape (cannot change 2x3 to a 2x5). | |||||
| dims[1] = 5; | |||||
| c_api.TF_GraphSetTensorShape(graph, feed_out_0, dims, 2, s); | |||||
| Assert.IsTrue(s.Code == TF_Code.TF_INVALID_ARGUMENT); | |||||
| // Test for a scalar. | // Test for a scalar. | ||||
| var three = c_test_util.ScalarConst(3, graph, s); | var three = c_test_util.ScalarConst(3, graph, s); | ||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | Assert.IsTrue(s.Code == TF_Code.TF_OK); | ||||
| var three_out_0 = new TF_Output { oper = three }; | |||||
| var three_out_0 = new TF_Output(three, 0); | |||||
| num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | num_dims = c_api.TF_GraphGetTensorNumDims(graph, three_out_0, s); | ||||
| Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| Assert.AreEqual(0, num_dims); | Assert.AreEqual(0, num_dims); | ||||
| c_api.TF_GraphGetTensorShape(graph, feed_out_0, null, num_dims, s); | |||||
| //Assert.IsTrue(s.Code == TF_Code.TF_OK); | |||||
| graph.Dispose(); | |||||
| s.Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -7,8 +7,32 @@ using Buffer = Tensorflow.Buffer; | |||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| { | { | ||||
| /// <summary> | |||||
| /// Port from `tensorflow\c\c_test_util.cc` | |||||
| /// </summary> | |||||
| public static class c_test_util | public static class c_test_util | ||||
| { | { | ||||
| public static Operation Add(Operation l, Operation r, Graph graph, Status s, string name = "add") | |||||
| { | |||||
| Operation op = null; | |||||
| AddOpHelper(l, r, graph, s, name, ref op, true); | |||||
| return op; | |||||
| } | |||||
| public static void AddOpHelper(Operation l, Operation r, Graph graph, Status s, string name, ref Operation op, bool check) | |||||
| { | |||||
| var desc = c_api.TF_NewOperation(graph, "AddN", name); | |||||
| c_api.TF_AddInputList(desc, new TF_Output[] | |||||
| { | |||||
| new TF_Output(l, 0), | |||||
| new TF_Output(r, 0), | |||||
| }, 2); | |||||
| op = c_api.TF_FinishOperation(desc, s); | |||||
| s.Check(); | |||||
| } | |||||
| public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | public static bool GetAttrValue(Operation oper, string attr_name, ref AttrValue attr_value, Status s) | ||||
| { | { | ||||
| var buffer = new Buffer(); | var buffer = new Buffer(); | ||||
| @@ -58,7 +82,7 @@ namespace TensorFlowNET.UnitTest | |||||
| return op; | return op; | ||||
| } | } | ||||
| public static Operation ScalarConst(int v, Graph graph, Status s, string name = "Const") | |||||
| public static Operation ScalarConst(int v, Graph graph, Status s, string name = "scalar") | |||||
| { | { | ||||
| return Const(new Tensor(v), graph, s, name); | return Const(new Tensor(v), graph, s, name); | ||||
| } | } | ||||