Implement SafeContextHandle as a wrapper for TFE_Contexttags/v0.20
| @@ -16,6 +16,7 @@ | |||||
| using System; | using System; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Tensorflow.Eager; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -64,7 +65,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns>TFE_TensorHandle*</returns> | /// <returns>TFE_TensorHandle*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, SafeStatusHandle status); | |||||
| public static extern IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) | /// Retrieves the full name of the device (e.g. /job:worker/replica:0/...) | ||||
| @@ -2,7 +2,7 @@ | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| public class Context : DisposableObject | |||||
| public sealed class Context : IDisposable | |||||
| { | { | ||||
| public const int GRAPH_MODE = 0; | public const int GRAPH_MODE = 0; | ||||
| public const int EAGER_MODE = 1; | public const int EAGER_MODE = 1; | ||||
| @@ -12,9 +12,11 @@ namespace Tensorflow.Eager | |||||
| public string scope_name = ""; | public string scope_name = ""; | ||||
| bool _initialized = false; | bool _initialized = false; | ||||
| public SafeContextHandle Handle { get; } | |||||
| public Context(ContextOptions opts, Status status) | public Context(ContextOptions opts, Status status) | ||||
| { | { | ||||
| _handle = c_api.TFE_NewContext(opts, status.Handle); | |||||
| Handle = c_api.TFE_NewContext(opts, status.Handle); | |||||
| status.Check(true); | status.Check(true); | ||||
| } | } | ||||
| @@ -29,16 +31,10 @@ namespace Tensorflow.Eager | |||||
| } | } | ||||
| public void start_step() | public void start_step() | ||||
| => c_api.TFE_ContextStartStep(_handle); | |||||
| => c_api.TFE_ContextStartStep(Handle); | |||||
| public void end_step() | public void end_step() | ||||
| => c_api.TFE_ContextEndStep(_handle); | |||||
| /// <summary> | |||||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||||
| /// </summary> | |||||
| protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TFE_DeleteContext(_handle); | |||||
| => c_api.TFE_ContextEndStep(Handle); | |||||
| public bool executing_eagerly() | public bool executing_eagerly() | ||||
| => default_execution_mode == EAGER_MODE; | => default_execution_mode == EAGER_MODE; | ||||
| @@ -48,10 +44,7 @@ namespace Tensorflow.Eager | |||||
| name : | name : | ||||
| "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | "cd2c89b7-88b7-44c8-ad83-06c2a9158347"; | ||||
| public static implicit operator IntPtr(Context ctx) | |||||
| => ctx._handle; | |||||
| public static implicit operator TFE_Context(Context ctx) | |||||
| => new TFE_Context(ctx._handle); | |||||
| public void Dispose() | |||||
| => Handle.Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -53,7 +53,7 @@ namespace Tensorflow.Eager | |||||
| { | { | ||||
| object value = null; | object value = null; | ||||
| byte isList = 0; | byte isList = 0; | ||||
| var attrType = c_api.TFE_OpNameGetAttrType(tf.context, Name, attr_name, ref isList, tf.status.Handle); | |||||
| var attrType = c_api.TFE_OpNameGetAttrType(tf.context.Handle, Name, attr_name, ref isList, tf.status.Handle); | |||||
| switch (attrType) | switch (attrType) | ||||
| { | { | ||||
| case TF_AttrType.TF_ATTR_BOOL: | case TF_AttrType.TF_ATTR_BOOL: | ||||
| @@ -0,0 +1,40 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| public sealed class SafeContextHandle : SafeTensorflowHandle | |||||
| { | |||||
| public SafeContextHandle() | |||||
| { | |||||
| } | |||||
| public SafeContextHandle(IntPtr handle) | |||||
| : base(handle) | |||||
| { | |||||
| } | |||||
| protected override bool ReleaseHandle() | |||||
| { | |||||
| c_api.TFE_DeleteContext(handle); | |||||
| SetHandle(IntPtr.Zero); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,23 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| public struct TFE_Context | |||||
| { | |||||
| IntPtr _handle; | |||||
| public TFE_Context(IntPtr handle) | |||||
| => _handle = handle; | |||||
| public static implicit operator TFE_Context(IntPtr handle) | |||||
| => new TFE_Context(handle); | |||||
| public static implicit operator IntPtr(TFE_Context tensor) | |||||
| => tensor._handle; | |||||
| public override string ToString() | |||||
| => $"TFE_Context {_handle}"; | |||||
| } | |||||
| } | |||||
| @@ -73,7 +73,7 @@ namespace Tensorflow | |||||
| public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); | public static extern TF_AttrType TFE_OpGetAttrType(IntPtr op, string attr_name, ref byte is_list, SafeStatusHandle status); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TF_AttrType TFE_OpNameGetAttrType(IntPtr ct, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | |||||
| public static extern TF_AttrType TFE_OpNameGetAttrType(SafeContextHandle ctx, string op_or_function_name, string attr_name, ref byte is_list, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the length (number of tensors) of the input argument `input_name` | /// Returns the length (number of tensors) of the input argument `input_name` | ||||
| @@ -114,13 +114,13 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns>TFE_Context*</returns> | /// <returns>TFE_Context*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_Context TFE_NewContext(IntPtr opts, SafeStatusHandle status); | |||||
| public static extern SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_Context TFE_ContextStartStep(IntPtr ctx); | |||||
| public static extern void TFE_ContextStartStep(SafeContextHandle ctx); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_Context TFE_ContextEndStep(IntPtr ctx); | |||||
| public static extern void TFE_ContextEndStep(SafeContextHandle ctx); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -148,7 +148,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_Op TFE_NewOp(IntPtr ctx, string op_or_function_name, SafeStatusHandle status); | |||||
| public static extern TFE_Op TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | /// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This | ||||
| @@ -317,7 +317,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TFE_ContextListDevices(IntPtr ctx, SafeStatusHandle status); | |||||
| public static extern IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -379,7 +379,7 @@ namespace Tensorflow | |||||
| /// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
| /// <param name="executor"></param> | /// <param name="executor"></param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_ContextSetExecutorForThread(IntPtr ctx, TFE_Executor executor); | |||||
| public static extern void TFE_ContextSetExecutorForThread(SafeContextHandle ctx, TFE_Executor executor); | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the Executor for current thread. | /// Returns the Executor for current thread. | ||||
| @@ -387,7 +387,7 @@ namespace Tensorflow | |||||
| /// <param name="ctx"></param> | /// <param name="ctx"></param> | ||||
| /// <returns>TFE_Executor*</returns> | /// <returns>TFE_Executor*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_Executor TFE_ContextGetExecutorForThread(IntPtr ctx); | |||||
| public static extern TFE_Executor TFE_ContextGetExecutorForThread(SafeContextHandle ctx); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -402,7 +402,7 @@ namespace Tensorflow | |||||
| /// <param name="status"></param> | /// <param name="status"></param> | ||||
| /// <returns>EagerTensorHandle</returns> | /// <returns>EagerTensorHandle</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern SafeStatusHandle TFE_FastPathExecute(IntPtr ctx, | |||||
| public static extern SafeStatusHandle TFE_FastPathExecute(SafeContextHandle ctx, | |||||
| string device_name, | string device_name, | ||||
| string op_name, | string op_name, | ||||
| string name, | string name, | ||||
| @@ -416,7 +416,7 @@ namespace Tensorflow | |||||
| public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op); | public delegate void TFE_FastPathExecute_SetOpAttrs(IntPtr op); | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern SafeStatusHandle TFE_QuickExecute(IntPtr ctx, | |||||
| public static extern SafeStatusHandle TFE_QuickExecute(SafeContextHandle ctx, | |||||
| string device_name, | string device_name, | ||||
| string op_name, | string op_name, | ||||
| IntPtr[] inputs, | IntPtr[] inputs, | ||||
| @@ -1,6 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | |||||
| using Buffer = System.Buffer; | using Buffer = System.Buffer; | ||||
| namespace TensorFlowNET.UnitTest | namespace TensorFlowNET.UnitTest | ||||
| @@ -92,7 +93,7 @@ namespace TensorFlowNET.UnitTest | |||||
| protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length) | protected void TFE_OpSetAttrString(IntPtr op, string attr_name, string value, uint length) | ||||
| => c_api.TFE_OpSetAttrString(op, attr_name, value, length); | => c_api.TFE_OpSetAttrString(op, attr_name, value, length); | ||||
| protected IntPtr TFE_NewOp(IntPtr ctx, string op_or_function_name, SafeStatusHandle status) | |||||
| protected IntPtr TFE_NewOp(SafeContextHandle ctx, string op_or_function_name, SafeStatusHandle status) | |||||
| => c_api.TFE_NewOp(ctx, op_or_function_name, status); | => c_api.TFE_NewOp(ctx, op_or_function_name, status); | ||||
| protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) | protected IntPtr TFE_NewTensorHandle(IntPtr t, SafeStatusHandle status) | ||||
| @@ -104,10 +105,7 @@ namespace TensorFlowNET.UnitTest | |||||
| protected IntPtr TFE_NewContextOptions() | protected IntPtr TFE_NewContextOptions() | ||||
| => c_api.TFE_NewContextOptions(); | => c_api.TFE_NewContextOptions(); | ||||
| protected void TFE_DeleteContext(IntPtr t) | |||||
| => c_api.TFE_DeleteContext(t); | |||||
| protected IntPtr TFE_NewContext(IntPtr opts, SafeStatusHandle status) | |||||
| protected SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status) | |||||
| => c_api.TFE_NewContext(opts, status); | => c_api.TFE_NewContext(opts, status); | ||||
| protected void TFE_DeleteContextOptions(IntPtr opts) | protected void TFE_DeleteContextOptions(IntPtr opts) | ||||
| @@ -131,7 +129,7 @@ namespace TensorFlowNET.UnitTest | |||||
| protected void TFE_DeleteExecutor(IntPtr executor) | protected void TFE_DeleteExecutor(IntPtr executor) | ||||
| => c_api.TFE_DeleteExecutor(executor); | => c_api.TFE_DeleteExecutor(executor); | ||||
| protected IntPtr TFE_ContextGetExecutorForThread(IntPtr ctx) | |||||
| protected IntPtr TFE_ContextGetExecutorForThread(SafeContextHandle ctx) | |||||
| => c_api.TFE_ContextGetExecutorForThread(ctx); | => c_api.TFE_ContextGetExecutorForThread(ctx); | ||||
| protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, SafeStatusHandle status) | protected void TFE_ExecutorWaitForAllPendingNodes(IntPtr executor, SafeStatusHandle status) | ||||
| @@ -146,7 +144,7 @@ namespace TensorFlowNET.UnitTest | |||||
| protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status) | protected string TFE_TensorHandleBackingDeviceName(IntPtr h, SafeStatusHandle status) | ||||
| => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | => c_api.StringPiece(c_api.TFE_TensorHandleBackingDeviceName(h, status)); | ||||
| protected IntPtr TFE_ContextListDevices(IntPtr ctx, SafeStatusHandle status) | |||||
| protected IntPtr TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status) | |||||
| => c_api.TFE_ContextListDevices(ctx, status); | => c_api.TFE_ContextListDevices(ctx, status); | ||||
| protected int TF_DeviceListCount(IntPtr list) | protected int TF_DeviceListCount(IntPtr list) | ||||
| @@ -161,7 +159,7 @@ namespace TensorFlowNET.UnitTest | |||||
| protected void TF_DeleteDeviceList(IntPtr list) | protected void TF_DeleteDeviceList(IntPtr list) | ||||
| => c_api.TF_DeleteDeviceList(list); | => c_api.TF_DeleteDeviceList(list); | ||||
| protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, IntPtr ctx, string device_name, SafeStatusHandle status) | |||||
| protected IntPtr TFE_TensorHandleCopyToDevice(IntPtr h, SafeContextHandle ctx, string device_name, SafeStatusHandle status) | |||||
| => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | => c_api.TFE_TensorHandleCopyToDevice(h, ctx, device_name, status); | ||||
| protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status) | protected void TFE_OpSetDevice(IntPtr op, string device_name, SafeStatusHandle status) | ||||
| @@ -1,7 +1,6 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | |||||
| namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
| { | { | ||||
| @@ -15,14 +14,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| { | { | ||||
| using var status = c_api.TF_NewStatus(); | using var status = c_api.TF_NewStatus(); | ||||
| var opts = c_api.TFE_NewContextOptions(); | var opts = c_api.TFE_NewContextOptions(); | ||||
| var ctx = c_api.TFE_NewContext(opts, status); | |||||
| c_api.TFE_DeleteContextOptions(opts); | |||||
| IntPtr devices; | |||||
| using (var ctx = c_api.TFE_NewContext(opts, status)) | |||||
| { | |||||
| c_api.TFE_DeleteContextOptions(opts); | |||||
| var devices = c_api.TFE_ContextListDevices(ctx, status); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| devices = c_api.TFE_ContextListDevices(ctx, status); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| } | |||||
| c_api.TFE_DeleteContext(ctx); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| int num_devices = c_api.TF_DeviceListCount(devices); | int num_devices = c_api.TF_DeviceListCount(devices); | ||||
| @@ -21,24 +21,28 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
| var opts = TFE_NewContextOptions(); | var opts = TFE_NewContextOptions(); | ||||
| c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); | c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); | ||||
| var ctx = TFE_NewContext(opts, status); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| TFE_DeleteContextOptions(opts); | |||||
| var m = TestMatrixTensorHandle(); | |||||
| var matmul = MatMulOp(ctx, m, m); | |||||
| var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; | |||||
| int num_retvals = 2; | |||||
| c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||||
| EXPECT_EQ(1, num_retvals); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| TFE_DeleteOp(matmul); | |||||
| TFE_DeleteTensorHandle(m); | |||||
| IntPtr t; | |||||
| using (var ctx = TFE_NewContext(opts, status)) | |||||
| { | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| TFE_DeleteContextOptions(opts); | |||||
| var m = TestMatrixTensorHandle(); | |||||
| var matmul = MatMulOp(ctx, m, m); | |||||
| var retvals = new IntPtr[] { IntPtr.Zero, IntPtr.Zero }; | |||||
| int num_retvals = 2; | |||||
| c_api.TFE_Execute(matmul, retvals, ref num_retvals, status); | |||||
| EXPECT_EQ(1, num_retvals); | |||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| TFE_DeleteOp(matmul); | |||||
| TFE_DeleteTensorHandle(m); | |||||
| t = TFE_TensorHandleResolve(retvals[0], status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| TFE_DeleteTensorHandle(retvals[0]); | |||||
| } | |||||
| var t = TFE_TensorHandleResolve(retvals[0], status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | |||||
| TFE_DeleteTensorHandle(retvals[0]); | |||||
| TFE_DeleteContext(ctx); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| var product = new float[4]; | var product = new float[4]; | ||||
| EXPECT_EQ(product.Length * sizeof(float), (int)TF_TensorByteSize(t)); | EXPECT_EQ(product.Length * sizeof(float), (int)TF_TensorByteSize(t)); | ||||
| @@ -16,7 +16,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| { | { | ||||
| using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
| var opts = TFE_NewContextOptions(); | var opts = TFE_NewContextOptions(); | ||||
| var ctx = TFE_NewContext(opts, status); | |||||
| using var ctx = TFE_NewContext(opts, status); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| TFE_DeleteContextOptions(opts); | TFE_DeleteContextOptions(opts); | ||||
| @@ -57,7 +57,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| TFE_DeleteTensorHandle(input2); | TFE_DeleteTensorHandle(input2); | ||||
| TFE_DeleteTensorHandle(retvals[0]); | TFE_DeleteTensorHandle(retvals[0]); | ||||
| TFE_DeleteTensorHandle(retvals[1]); | TFE_DeleteTensorHandle(retvals[1]); | ||||
| TFE_DeleteContext(ctx); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -18,7 +18,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| { | { | ||||
| using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
| var opts = TFE_NewContextOptions(); | var opts = TFE_NewContextOptions(); | ||||
| var ctx = TFE_NewContext(opts, status); | |||||
| using var ctx = TFE_NewContext(opts, status); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| TFE_DeleteContextOptions(opts); | TFE_DeleteContextOptions(opts); | ||||
| @@ -50,7 +50,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| TFE_DeleteTensorHandle(t1); | TFE_DeleteTensorHandle(t1); | ||||
| TFE_DeleteTensorHandle(t2); | TFE_DeleteTensorHandle(t2); | ||||
| TFE_DeleteTensorHandle(retvals[0]); | TFE_DeleteTensorHandle(retvals[0]); | ||||
| TFE_DeleteContext(ctx); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -16,7 +16,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| { | { | ||||
| var status = c_api.TF_NewStatus(); | var status = c_api.TF_NewStatus(); | ||||
| var opts = TFE_NewContextOptions(); | var opts = TFE_NewContextOptions(); | ||||
| var ctx = TFE_NewContext(opts, status); | |||||
| using var ctx = TFE_NewContext(opts, status); | |||||
| TFE_DeleteContextOptions(opts); | TFE_DeleteContextOptions(opts); | ||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| @@ -65,7 +65,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| TFE_ExecutorWaitForAllPendingNodes(executor, status); | TFE_ExecutorWaitForAllPendingNodes(executor, status); | ||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| TFE_DeleteExecutor(executor); | TFE_DeleteExecutor(executor); | ||||
| TFE_DeleteContext(ctx); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -15,7 +15,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| { | { | ||||
| using var status = c_api.TF_NewStatus(); | using var status = c_api.TF_NewStatus(); | ||||
| var opts = TFE_NewContextOptions(); | var opts = TFE_NewContextOptions(); | ||||
| var ctx = TFE_NewContext(opts, status); | |||||
| using var ctx = TFE_NewContext(opts, status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| TFE_DeleteContextOptions(opts); | TFE_DeleteContextOptions(opts); | ||||
| @@ -47,7 +47,6 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| TFE_DeleteTensorHandle(var_handle); | TFE_DeleteTensorHandle(var_handle); | ||||
| TFE_DeleteTensorHandle(value_handle[0]); | TFE_DeleteTensorHandle(value_handle[0]); | ||||
| TFE_DeleteContext(ctx); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
| @@ -25,7 +26,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| return th; | return th; | ||||
| } | } | ||||
| IntPtr MatMulOp(IntPtr ctx, IntPtr a, IntPtr b) | |||||
| IntPtr MatMulOp(SafeContextHandle ctx, IntPtr a, IntPtr b) | |||||
| { | { | ||||
| using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
| @@ -40,7 +41,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| return op; | return op; | ||||
| } | } | ||||
| bool GetDeviceName(IntPtr ctx, ref string device_name, string device_type) | |||||
| bool GetDeviceName(SafeContextHandle ctx, ref string device_name, string device_type) | |||||
| { | { | ||||
| var status = TF_NewStatus(); | var status = TF_NewStatus(); | ||||
| var devices = TFE_ContextListDevices(ctx, status); | var devices = TFE_ContextListDevices(ctx, status); | ||||
| @@ -65,7 +66,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| return false; | return false; | ||||
| } | } | ||||
| IntPtr ShapeOp(IntPtr ctx, IntPtr a) | |||||
| IntPtr ShapeOp(SafeContextHandle ctx, IntPtr a) | |||||
| { | { | ||||
| using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
| @@ -78,7 +79,7 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| return op; | return op; | ||||
| } | } | ||||
| unsafe IntPtr CreateVariable(IntPtr ctx, float value, SafeStatusHandle status) | |||||
| unsafe IntPtr CreateVariable(SafeContextHandle ctx, float value, SafeStatusHandle status) | |||||
| { | { | ||||
| var op = TFE_NewOp(ctx, "VarHandleOp", status); | var op = TFE_NewOp(ctx, "VarHandleOp", status); | ||||
| if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | if (TF_GetCode(status) != TF_OK) return IntPtr.Zero; | ||||