From 9dedbb5f0e9e223f25e405bbe062d850ebbd8cd3 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 7 Mar 2020 17:01:36 -0600 Subject: [PATCH] Eager.TensorHandle unit test. --- src/TensorFlowNET.Core/Eager/c_api.eager.cs | 18 +++++++++ test/TensorFlowNET.UnitTest/CApiTest.cs | 3 ++ .../Eager/CApi.Eager.Context.cs | 2 +- .../Eager/CApi.Eager.TensorHandle.cs | 38 +++++++++++++++++++ .../Eager/CApi.Eager.cs | 20 +++++++++- 5 files changed, 78 insertions(+), 3 deletions(-) create mode 100644 test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandle.cs diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 9e89d234..bb641505 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -139,6 +139,17 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern TF_DataType TFE_TensorHandleDataType(IntPtr h); + /// + /// This function will block till the operation that produces `h` has + /// completed. The memory returned might alias the internal memory used by + /// TensorFlow. + /// + /// TFE_TensorHandle* + /// TF_Status* + /// + [DllImport(TensorFlowLibName)] + public static extern IntPtr TFE_TensorHandleResolve(IntPtr h, IntPtr status); + /// /// This function will block till the operation that produces `h` has completed. /// @@ -156,5 +167,12 @@ namespace Tensorflow /// [DllImport(TensorFlowLibName)] public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, IntPtr status); + + /// + /// + /// + /// TFE_TensorHandle* + [DllImport(TensorFlowLibName)] + public static extern void TFE_DeleteTensorHandle(IntPtr h); } } diff --git a/test/TensorFlowNET.UnitTest/CApiTest.cs b/test/TensorFlowNET.UnitTest/CApiTest.cs index 5a7c29f4..bd43e40b 100644 --- a/test/TensorFlowNET.UnitTest/CApiTest.cs +++ b/test/TensorFlowNET.UnitTest/CApiTest.cs @@ -15,6 +15,9 @@ namespace TensorFlowNET.UnitTest protected void EXPECT_EQ(object expected, object actual, string msg = "") => Assert.AreEqual(expected, actual, msg); + protected void CHECK_EQ(object expected, object actual, string msg = "") + => Assert.AreEqual(expected, actual, msg); + protected void EXPECT_NE(object expected, object actual, string msg = "") => Assert.AreNotEqual(expected, actual, msg); diff --git a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.Context.cs b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.Context.cs index 05d34d20..b966f0ea 100644 --- a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.Context.cs +++ b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.Context.cs @@ -33,7 +33,7 @@ namespace TensorFlowNET.UnitTest.Eager EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); } - c_api.TF_DeleteDeviceList(devices); + // c_api.TF_DeleteDeviceList(devices); c_api.TF_DeleteStatus(status); } } diff --git a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandle.cs b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandle.cs new file mode 100644 index 00000000..95e0d81b --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandle.cs @@ -0,0 +1,38 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using Tensorflow; +using Tensorflow.Eager; +using Buffer = System.Buffer; + +namespace TensorFlowNET.UnitTest.Eager +{ + public partial class CApiEagerTest + { + /// + /// TEST(CAPI, TensorHandle) + /// + [TestMethod] + public unsafe void TensorHandle() + { + var h = TestMatrixTensorHandle(); + EXPECT_EQ(TF_FLOAT, c_api.TFE_TensorHandleDataType(h)); + + var status = c_api.TF_NewStatus(); + var t = c_api.TFE_TensorHandleResolve(h, status); + ASSERT_EQ(16ul, c_api.TF_TensorByteSize(t)); + + var data = new float[] { 0f, 0f, 0f, 0f }; + fixed (void* src = &data[0]) + { + Buffer.MemoryCopy((void*)c_api.TF_TensorData(t), src, data.Length * sizeof(float), (long)c_api.TF_TensorByteSize(t)); + } + + EXPECT_EQ(1.0f, data[0]); + EXPECT_EQ(2.0f, data[1]); + EXPECT_EQ(3.0f, data[2]); + EXPECT_EQ(4.0f, data[3]); + c_api.TF_DeleteTensor(t); + c_api.TFE_DeleteTensorHandle(h); + } + } +} diff --git a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs index c79ffc78..acd1e9b7 100644 --- a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs +++ b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs @@ -2,6 +2,7 @@ using System; using Tensorflow; using Tensorflow.Eager; +using Buffer = System.Buffer; namespace TensorFlowNET.UnitTest.Eager { @@ -11,7 +12,22 @@ namespace TensorFlowNET.UnitTest.Eager [TestClass] public partial class CApiEagerTest : CApiTest { - - + unsafe IntPtr TestMatrixTensorHandle() + { + var dims = new long[] { 2, 2 }; + var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; + var t = c_api.TF_AllocateTensor(TF_FLOAT, dims, dims.Length, (ulong)data.Length * sizeof(float)); + fixed(void *src = &data[0]) + { + Buffer.MemoryCopy(src, (void*)c_api.TF_TensorData(t), (long)c_api.TF_TensorByteSize(t), data.Length * sizeof(float)); + } + + var status = c_api.TF_NewStatus(); + var th = c_api.TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); + c_api.TF_DeleteTensor(t); + c_api.TF_DeleteStatus(status); + return th; + } } }