| @@ -139,6 +139,17 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TF_DataType TFE_TensorHandleDataType(IntPtr h); | public static extern TF_DataType TFE_TensorHandleDataType(IntPtr h); | ||||
| /// <summary> | |||||
| /// This function will block till the operation that produces `h` has | |||||
| /// completed. The memory returned might alias the internal memory used by | |||||
| /// TensorFlow. | |||||
| /// </summary> | |||||
| /// <param name="h">TFE_TensorHandle*</param> | |||||
| /// <param name="status">TF_Status*</param> | |||||
| /// <returns></returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern IntPtr TFE_TensorHandleResolve(IntPtr h, IntPtr status); | |||||
| /// <summary> | /// <summary> | ||||
| /// This function will block till the operation that produces `h` has completed. | /// This function will block till the operation that produces `h` has completed. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -156,5 +167,12 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, IntPtr status); | public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, IntPtr status); | ||||
| /// <summary> | |||||
| /// | |||||
| /// </summary> | |||||
| /// <param name="h">TFE_TensorHandle*</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_DeleteTensorHandle(IntPtr h); | |||||
| } | } | ||||
| } | } | ||||
| @@ -15,6 +15,9 @@ namespace TensorFlowNET.UnitTest | |||||
| protected void EXPECT_EQ(object expected, object actual, string msg = "") | protected void EXPECT_EQ(object expected, object actual, string msg = "") | ||||
| => Assert.AreEqual(expected, actual, msg); | => Assert.AreEqual(expected, actual, msg); | ||||
| protected void CHECK_EQ(object expected, object actual, string msg = "") | |||||
| => Assert.AreEqual(expected, actual, msg); | |||||
| protected void EXPECT_NE(object expected, object actual, string msg = "") | protected void EXPECT_NE(object expected, object actual, string msg = "") | ||||
| => Assert.AreNotEqual(expected, actual, msg); | => Assert.AreNotEqual(expected, actual, msg); | ||||
| @@ -33,7 +33,7 @@ namespace TensorFlowNET.UnitTest.Eager | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| } | } | ||||
| c_api.TF_DeleteDeviceList(devices); | |||||
| // c_api.TF_DeleteDeviceList(devices); | |||||
| c_api.TF_DeleteStatus(status); | c_api.TF_DeleteStatus(status); | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,38 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using System; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Eager; | |||||
| using Buffer = System.Buffer; | |||||
| namespace TensorFlowNET.UnitTest.Eager | |||||
| { | |||||
| public partial class CApiEagerTest | |||||
| { | |||||
| /// <summary> | |||||
| /// TEST(CAPI, TensorHandle) | |||||
| /// </summary> | |||||
| [TestMethod] | |||||
| public unsafe void TensorHandle() | |||||
| { | |||||
| var h = TestMatrixTensorHandle(); | |||||
| EXPECT_EQ(TF_FLOAT, c_api.TFE_TensorHandleDataType(h)); | |||||
| var status = c_api.TF_NewStatus(); | |||||
| var t = c_api.TFE_TensorHandleResolve(h, status); | |||||
| ASSERT_EQ(16ul, c_api.TF_TensorByteSize(t)); | |||||
| var data = new float[] { 0f, 0f, 0f, 0f }; | |||||
| fixed (void* src = &data[0]) | |||||
| { | |||||
| Buffer.MemoryCopy((void*)c_api.TF_TensorData(t), src, data.Length * sizeof(float), (long)c_api.TF_TensorByteSize(t)); | |||||
| } | |||||
| EXPECT_EQ(1.0f, data[0]); | |||||
| EXPECT_EQ(2.0f, data[1]); | |||||
| EXPECT_EQ(3.0f, data[2]); | |||||
| EXPECT_EQ(4.0f, data[3]); | |||||
| c_api.TF_DeleteTensor(t); | |||||
| c_api.TFE_DeleteTensorHandle(h); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -2,6 +2,7 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Buffer = System.Buffer; | |||||
| namespace TensorFlowNET.UnitTest.Eager | namespace TensorFlowNET.UnitTest.Eager | ||||
| { | { | ||||
| @@ -11,7 +12,22 @@ namespace TensorFlowNET.UnitTest.Eager | |||||
| [TestClass] | [TestClass] | ||||
| public partial class CApiEagerTest : CApiTest | public partial class CApiEagerTest : CApiTest | ||||
| { | { | ||||
| unsafe IntPtr TestMatrixTensorHandle() | |||||
| { | |||||
| var dims = new long[] { 2, 2 }; | |||||
| var data = new float[] { 1.0f, 2.0f, 3.0f, 4.0f }; | |||||
| var t = c_api.TF_AllocateTensor(TF_FLOAT, dims, dims.Length, (ulong)data.Length * sizeof(float)); | |||||
| fixed(void *src = &data[0]) | |||||
| { | |||||
| Buffer.MemoryCopy(src, (void*)c_api.TF_TensorData(t), (long)c_api.TF_TensorByteSize(t), data.Length * sizeof(float)); | |||||
| } | |||||
| var status = c_api.TF_NewStatus(); | |||||
| var th = c_api.TFE_NewTensorHandle(t, status); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| c_api.TF_DeleteTensor(t); | |||||
| c_api.TF_DeleteStatus(status); | |||||
| return th; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||