| @@ -0,0 +1,40 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| public sealed class SafeExecutorHandle : SafeTensorflowHandle | |||||
| { | |||||
| private SafeExecutorHandle() | |||||
| { | |||||
| } | |||||
| public SafeExecutorHandle(IntPtr handle) | |||||
| : base(handle) | |||||
| { | |||||
| } | |||||
| protected override bool ReleaseHandle() | |||||
| { | |||||
| c_api.TFE_DeleteExecutor(handle); | |||||
| SetHandle(IntPtr.Zero); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,23 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| public struct TFE_Executor | |||||
| { | |||||
| IntPtr _handle; | |||||
| public TFE_Executor(IntPtr handle) | |||||
| => _handle = handle; | |||||
| public static implicit operator TFE_Executor(IntPtr handle) | |||||
| => new TFE_Executor(handle); | |||||
| public static implicit operator IntPtr(TFE_Executor tensor) | |||||
| => tensor._handle; | |||||
| public override string ToString() | |||||
| => $"TFE_Executor {_handle}"; | |||||
| } | |||||
| } | |||||
| @@ -2,7 +2,6 @@ | |||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Tensorflow.Device; | using Tensorflow.Device; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using TFE_Executor = System.IntPtr; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -299,7 +298,7 @@ namespace Tensorflow | |||||
| /// <param name="is_async"></param> | /// <param name="is_async"></param> | ||||
| /// <returns>TFE_Executor*</returns> | /// <returns>TFE_Executor*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_Executor TFE_NewExecutor(bool is_async); | |||||
| public static extern SafeExecutorHandle TFE_NewExecutor(bool is_async); | |||||
| /// <summary> | /// <summary> | ||||
| /// Deletes the eager Executor without waiting for enqueued nodes. Please call | /// Deletes the eager Executor without waiting for enqueued nodes. Please call | ||||
| @@ -322,7 +321,7 @@ namespace Tensorflow | |||||
| /// <param name="executor">TFE_Executor*</param> | /// <param name="executor">TFE_Executor*</param> | ||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor executor, SafeStatusHandle status); | |||||
| public static extern void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Sets a custom Executor for current thread. All nodes created by this thread | /// Sets a custom Executor for current thread. All nodes created by this thread | ||||
| @@ -331,7 +330,7 @@ namespace Tensorflow | |||||
| /// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
| /// <param name="executor"></param> | /// <param name="executor"></param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, TFE_Executor executor); | |||||
| public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, SafeExecutorHandle executor); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the Executor for current thread. | /// Returns the Executor for current thread. | ||||
| @@ -339,7 +338,7 @@ namespace Tensorflow | |||||
| /// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
| /// <returns>TFE_Executor*</returns> | /// <returns>TFE_Executor*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); | |||||
| public static extern SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | public static extern IntPtr TFE_TapeSetNew(bool persistent, bool watch_accessed_variables); | ||||
| @@ -124,13 +124,10 @@ namespace TensorFlowNET.UnitTest | |||||
| protected void TFE_DeleteOp(IntPtr op) | protected void TFE_DeleteOp(IntPtr op) | ||||
| => c_api.TFE_DeleteOp(op); | => c_api.TFE_DeleteOp(op); | ||||
| protected void TFE_DeleteExecutor(IntPtr executor) | |||||
| => c_api.TFE_DeleteExecutor(executor); | |||||
| protected IntPtr TFE_ContextGetExecutorForThread(SafeContextHandle ctx) | |||||
| protected SafeExecutorHandle TFE_ContextGetExecutorForThread(SafeContextHandle ctx) | |||||
| => c_api.TFE_ContextGetExecutorForThread(ctx); | => c_api.TFE_ContextGetExecutorForThread(ctx); | ||||
| protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, SafeStatusHandle status) | |||||
| protected void TFE_ExecutorWaitForAllPendingNodes(SafeExecutorHandle executor, SafeStatusHandle status) | |||||
| => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); | => c_api.TFE_ExecutorWaitForAllPendingNodes(executor, status); | ||||
| protected IntPtr TFE_TensorHandleResolve(IntPtr h, SafeStatusHandle status) | protected IntPtr TFE_TensorHandleResolve(IntPtr h, SafeStatusHandle status) | ||||
| @@ -65,10 +65,9 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| TFE_DeleteTensorHandle(hcpu); | TFE_DeleteTensorHandle(hcpu); | ||||
| // not export api | // not export api | ||||
| var executor = TFE_ContextGetExecutorForThread(ctx); | |||||
| using var executor = TFE_ContextGetExecutorForThread(ctx); | |||||
| TFE_ExecutorWaitForAllPendingNodes(executor, status); | TFE_ExecutorWaitForAllPendingNodes(executor, status); | ||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| TFE_DeleteExecutor(executor); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||