diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
index 7fd81af8..7da0a7e3 100644
--- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs
+++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
@@ -1,5 +1,6 @@
using System;
using System.Runtime.InteropServices;
+using TFE_Executor = System.IntPtr;
namespace Tensorflow
{
@@ -196,5 +197,55 @@ namespace Tensorflow
/// TFE_TensorHandle*
[DllImport(TensorFlowLibName)]
public static extern void TFE_DeleteTensorHandle(IntPtr h);
+
+ /// <summary>
+ /// Creates a new eager Executor. Nodes in one executor are guaranteed to be
+ /// executed in sequence. Assigning nodes to different executors allows executing
+ /// nodes in parallel.
+ /// </summary>
+ /// <param name="is_async"></param>
+ /// <returns>TFE_Executor*</returns>
+ [DllImport(TensorFlowLibName)]
+ public static extern IntPtr TFE_NewExecutor(bool is_async);
+
+ /// <summary>
+ /// Deletes the eager Executor without waiting for enqueued nodes. Please call
+ /// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to
+ /// make sure all nodes are finished.
+ /// </summary>
+ /// <param name="executor">TFE_Executor*</param>
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFE_DeleteExecutor(IntPtr executor);
+
+ /// <summary>
+ /// Causes the calling thread to block till all ops dispatched in this executor
+ /// have been executed. Note that "execution" here refers to kernel execution /
+ /// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
+ /// that lower level device queues (like GPU streams) have been flushed.
+ ///
+ /// This call may not block for execution of ops enqueued concurrently with this
+ /// call.
+ /// </summary>
+ /// <param name="executor">TFE_Executor*</param>
+ /// <param name="status">TF_Status*</param>
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor executor, IntPtr status);
+
+ /// <summary>
+ /// Sets a custom Executor for current thread. All nodes created by this thread
+ /// will be added to this Executor. It will override current executor.
+ /// </summary>
+ /// <param name="ctx"></param>
+ /// <param name="executor"></param>
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFE_ContextSetExecutorForThread(IntPtr ctx, TFE_Executor executor);
+
+ /// <summary>
+ /// Returns the Executor for current thread.
+ /// </summary>
+ /// <param name="ctx"></param>
+ /// <returns>TFE_Executor*</returns>
+ [DllImport(TensorFlowLibName)]
+ public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx);
}
}
diff --git a/test/TensorFlowNET.UnitTest/CApiTest.cs b/test/TensorFlowNET.UnitTest/CApiTest.cs
index 45ac9b1b..fc5121ad 100644
--- a/test/TensorFlowNET.UnitTest/CApiTest.cs
+++ b/test/TensorFlowNET.UnitTest/CApiTest.cs
@@ -100,6 +100,15 @@ namespace TensorFlowNET.UnitTest
protected void TFE_DeleteOp(IntPtr op)
=> c_api.TFE_DeleteOp(op);
+ protected void TFE_DeleteExecutor(IntPtr executor)
+ => c_api.TFE_DeleteExecutor(executor);
+
+ protected IntPtr TFE_ContextGetExecutorForThread(IntPtr ctx)
+ => c_api.TFE_ContextGetExecutorForThread(ctx);
+
+ protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, IntPtr status)
+ => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status);
+
protected IntPtr TFE_TensorHandleResolve(IntPtr h, IntPtr status)
=> c_api.TFE_TensorHandleResolve(h, status);
@@ -127,6 +136,9 @@ namespace TensorFlowNET.UnitTest
protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, IntPtr status)
=> c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status);
+ protected void TFE_OpSetDevice(IntPtr op, string device_name, IntPtr status)
+ => c_api.TFE_OpSetDevice(op, device_name, status);
+
protected unsafe void memcpy(void * src, IntPtr dst, ulong size)
{
Buffer.MemoryCopy(src, dst.ToPointer(), size, size);
diff --git a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs
index eae59b08..2f0b9dc3 100644
--- a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs
+++ b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.TensorHandleDevices.cs
@@ -36,8 +36,36 @@ namespace TensorFlowNET.UnitTest.Eager
var hgpu = TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_device_name, status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
- // shape_op = ShapeOp(ctx, hgpu);
+ var shape_op = ShapeOp(ctx, hgpu);
+ TFE_OpSetDevice(shape_op, gpu_device_name, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
+ var retvals = new IntPtr[1];
+ int num_retvals = 1;
+ c_api.TFE_Execute(shape_op, retvals, ref num_retvals, status);
+ ASSERT_TRUE(TF_GetCode(status) == TF_OK, TF_Message(status));
+
+ // .device of shape is GPU since the op is executed on GPU
+ device_name = TFE_TensorHandleDeviceName(retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+ ASSERT_TRUE(device_name.Contains("GPU:0"));
+
+ // .backing_device of shape is CPU since the tensor is backed by CPU
+ backing_device_name = TFE_TensorHandleBackingDeviceName(retvals[0], status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+ ASSERT_TRUE(backing_device_name.Contains("CPU:0"));
+
+ TFE_DeleteOp(shape_op);
+ TFE_DeleteTensorHandle(retvals[0]);
+ TFE_DeleteTensorHandle(hgpu);
}
+
+ TFE_DeleteTensorHandle(hcpu);
+ // not export api
+ /*var executor = TFE_ContextGetExecutorForThread(ctx);
+ TFE_ExecutorWaitForAllPendingNodes(executor, status);
+ ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+ TFE_DeleteExecutor(executor);*/
+ TFE_DeleteContext(ctx);
}
}
}
diff --git a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs
index f941c3c4..0a9db179 100644
--- a/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs
+++ b/test/TensorFlowNET.UnitTest/Eager/CApi.Eager.cs
@@ -1,8 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using Tensorflow;
-using Tensorflow.Eager;
-using Buffer = System.Buffer;
namespace TensorFlowNET.UnitTest.Eager
{
@@ -67,5 +65,19 @@ namespace TensorFlowNET.UnitTest.Eager
TF_DeleteDeviceList(devices);
return false;
}
+
+ IntPtr ShapeOp(IntPtr ctx, IntPtr a)
+ {
+ var status = TF_NewStatus();
+
+ var op = TFE_NewOp(ctx, "Shape", status);
+ CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+ TFE_OpAddInput(op, a, status);
+ CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status));
+ TF_DeleteStatus(status);
+ TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(a));
+
+ return op;
+ }
}
}
diff --git a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj
index a84ad93c..b97ae1c1 100644
--- a/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj
+++ b/test/TensorFlowNET.UnitTest/Tensorflow.UnitTest.csproj
@@ -33,7 +33,7 @@
-
+