diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
index 7fd81af8..7da0a7e3 100644
--- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs
+++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Runtime.InteropServices;
+using TFE_Executor = System.IntPtr;
 
 namespace Tensorflow
 {
@@ -196,5 +197,55 @@ namespace Tensorflow
         /// <param name="h">TFE_TensorHandle*</param>
         [DllImport(TensorFlowLibName)]
         public static extern void TFE_DeleteTensorHandle(IntPtr h);
+
+        /// <summary>
+        /// Creates a new eager Executor. Nodes in one executor are guaranteed to be
+        /// executed in sequence. Assigning nodes to different executors allows executing
+        /// nodes in parallel.
+        /// </summary>
+        /// <param name="is_async">C `bool` (1 byte), hence the I1 marshaling</param>
+        /// <returns>TFE_Executor*</returns>
+        [DllImport(TensorFlowLibName)]
+        public static extern TFE_Executor TFE_NewExecutor([MarshalAs(UnmanagedType.I1)] bool is_async);
+
+        /// <summary>
+        /// Deletes the eager Executor without waiting for enqueued nodes. Please call
+        /// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to
+        /// make sure all nodes are finished.
+        /// </summary>
+        /// <param name="executor">TFE_Executor*</param>
+        [DllImport(TensorFlowLibName)]
+        public static extern void TFE_DeleteExecutor(TFE_Executor executor);
+
+        /// <summary>
+        /// Causes the calling thread to block till all ops dispatched in this executor
+        /// have been executed. Note that "execution" here refers to kernel execution /
+        /// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
+        /// that lower level device queues (like GPU streams) have been flushed.
+        ///
+        /// This call may not block for execution of ops enqueued concurrently with this
+        /// call.
+        /// </summary>
+        /// <param name="executor">TFE_Executor*</param>
+        /// <param name="status">TF_Status*</param>
+        [DllImport(TensorFlowLibName)]
+        public static extern void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor executor, IntPtr status);
+
+        /// <summary>
+        /// Sets a custom Executor for current thread. All nodes created by this thread
+        /// will be added to this Executor. It will override current executor.
+        /// </summary>
+        /// <param name="ctx">TFE_Context*</param>
+        /// <param name="executor">TFE_Executor*</param>
+        [DllImport(TensorFlowLibName)]
+        public static extern void TFE_ContextSetExecutorForThread(IntPtr ctx, TFE_Executor executor);
+
+        /// <summary>
+        /// Returns the Executor for current thread.
+        /// </summary>
+        /// <param name="ctx">TFE_Context*</param>
+        /// <returns>TFE_Executor*</returns>
+        [DllImport(TensorFlowLibName)]
+        public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx);
     }
 }
diff --git a/test/TensorFlowNET.UnitTest/CApiTest.cs b/test/TensorFlowNET.UnitTest/CApiTest.cs
index 45ac9b1b..fc5121ad 100644
--- a/test/TensorFlowNET.UnitTest/CApiTest.cs
+++ b/test/TensorFlowNET.UnitTest/CApiTest.cs
@@ -100,6 +100,15 @@ namespace TensorFlowNET.UnitTest
         protected void TFE_DeleteOp(IntPtr op)
             => c_api.TFE_DeleteOp(op);
 
+        protected void TFE_DeleteExecutor(IntPtr executor)
+            => c_api.TFE_DeleteExecutor(executor);
+
+        protected IntPtr TFE_ContextGetExecutorForThread(IntPtr ctx)
+            => c_api.TFE_ContextGetExecutorForThread(ctx);
+
+        protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, IntPtr status)
+            => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status);
+
         protected IntPtr TFE_TensorHandleResolve(IntPtr h, IntPtr status)
             => c_api.TFE_TensorHandleResolve(h, status);
 
@@ -127,6 +136,9 @@ namespace TensorFlowNET.UnitTest
         protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, IntPtr status)
             => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);
 
+        protected void TFE_OpSetDevice(IntPtr op, string device_name, IntPtr status)
+            => c_api.TFE_OpSetDevice(op, device_name, status);
+
         protected unsafe void memcpy(void * src, IntPtr dst, ulong size)
         {
             Buffer.MemoryCopy(src, dst.ToPointer(), size, size);
diff --git a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs
index eae59b08..2f0b9dc3 100644
--- a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs
+++ b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs
@@ -36,8 +36,36 @@ namespace TensorFlowNET.UnitTest.Eager
                 var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status);
                 ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
 
-                // shape_op = ShapeOp(ctx, hgpu);
+                var shape_op = ShapeOp(ctx, hgpu);
+                TFE_OpSetDevice(shape_op, gpu_device_name, status);
+                ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
+                var retvals = new IntPtr[1];
+                int num_retvals = 1;
+                c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status);
+                ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
+
+                // .device of shape is GPU since the op is executed on GPU
+                device_name = TFE_TensorHandleDeviceName(retvals[0], status);
+                ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                ASSERT_TRUE(device_name.Contains("GPU:0"));
+
+                // .backing_device of shape is CPU since the tensor is backed by CPU
+                backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status);
+                ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+                ASSERT_TRUE(backing_device_name.Contains("CPU:0"));
+
+                TFE_DeleteOp(shape_op);
+                TFE_DeleteTensorHandle(retvals[0]);
+                TFE_DeleteTensorHandle(hgpu);
             }
+
+            TFE_DeleteTensorHandle(hcpu);
+            // not export api
+            /*var executor = TFE_ContextGetExecutorForThread(ctx);
+            TFE_ExecutorWaitForAllPendingNodes(executor, status);
+            ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+            TFE_DeleteExecutor(executor);*/
+            TFE_DeleteContext(ctx);
         }
     }
 }
diff --git a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs
index f941c3c4..0a9db179 100644
--- a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs
+++ b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs
@@ -1,8 +1,6 @@
 using Microsoft.VisualStudio.TestTools.UnitTesting;
 using System;
 using Tensorflow;
-using Tensorflow.Eager;
-using Buffer = System.Buffer;
 
 namespace TensorFlowNET.UnitTest.Eager
 {
@@ -67,5 +65,19 @@ namespace TensorFlowNET.UnitTest.Eager
             TF_DeleteDeviceList(devices);
             return false;
         }
+
+        IntPtr ShapeOp(IntPtr ctx, IntPtr a)
+        {
+            var status = TF_NewStatus();
+
+            var op = TFE_NewOp(ctx, "Shape", status);
+            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+            TFE_OpAddInput(op, a, status);
+            CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+            TF_DeleteStatus(status);
+            TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
+
+            return op;
+        }
     }
 }
diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj
index a84ad93c..b97ae1c1 100644
--- a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj
+++ b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj
@@ -33,7 +33,7 @@
-    
+    