Implement SafeContextOptionsHandle as a wrapper for TFE_ContextOptionstags/v0.20
| @@ -16,7 +16,7 @@ namespace Tensorflow.Eager | |||||
| public Context(ContextOptions opts, Status status) | public Context(ContextOptions opts, Status status) | ||||
| { | { | ||||
| Handle = c_api.TFE_NewContext(opts, status.Handle); | |||||
| Handle = c_api.TFE_NewContext(opts.Handle, status.Handle); | |||||
| status.Check(true); | status.Check(true); | ||||
| } | } | ||||
| @@ -1,26 +1,33 @@ | |||||
| using System; | |||||
| using System.IO; | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| namespace Tensorflow.Eager | namespace Tensorflow.Eager | ||||
| { | { | ||||
| public class ContextOptions : DisposableObject | |||||
| public sealed class ContextOptions : IDisposable | |||||
| { | { | ||||
| public ContextOptions() : base(c_api.TFE_NewContextOptions()) | |||||
| { } | |||||
| public SafeContextOptionsHandle Handle { get; } | |||||
| /// <summary> | |||||
| /// Dispose any unmanaged resources related to given <paramref name="handle"/>. | |||||
| /// </summary> | |||||
| protected sealed override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TFE_DeleteContextOptions(_handle); | |||||
| public ContextOptions() | |||||
| { | |||||
| Handle = c_api.TFE_NewContextOptions(); | |||||
| } | |||||
| public static implicit operator IntPtr(ContextOptions opts) | |||||
| => opts._handle; | |||||
| public static implicit operator TFE_ContextOptions(ContextOptions opts) | |||||
| => new TFE_ContextOptions(opts._handle); | |||||
| public void Dispose() | |||||
| => Handle.Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,40 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| public sealed class SafeContextOptionsHandle : SafeTensorflowHandle | |||||
| { | |||||
| public SafeContextOptionsHandle() | |||||
| { | |||||
| } | |||||
| public SafeContextOptionsHandle(IntPtr handle) | |||||
| : base(handle) | |||||
| { | |||||
| } | |||||
| protected override bool ReleaseHandle() | |||||
| { | |||||
| c_api.TFE_DeleteContextOptions(handle); | |||||
| SetHandle(IntPtr.Zero); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,23 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Eager | |||||
| { | |||||
| public struct TFE_ContextOptions | |||||
| { | |||||
| IntPtr _handle; | |||||
| public TFE_ContextOptions(IntPtr handle) | |||||
| => _handle = handle; | |||||
| public static implicit operator TFE_ContextOptions(IntPtr handle) | |||||
| => new TFE_ContextOptions(handle); | |||||
| public static implicit operator IntPtr(TFE_ContextOptions tensor) | |||||
| => tensor._handle; | |||||
| public override string ToString() | |||||
| => $"TFE_ContextOptions {_handle}"; | |||||
| } | |||||
| } | |||||
| @@ -52,7 +52,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <returns>TFE_ContextOptions*</returns> | /// <returns>TFE_ContextOptions*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern TFE_ContextOptions TFE_NewContextOptions(); | |||||
| public static extern SafeContextOptionsHandle TFE_NewContextOptions(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Destroy an options object. | /// Destroy an options object. | ||||
| @@ -114,7 +114,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns>TFE_Context*</returns> | /// <returns>TFE_Context*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status); | |||||
| public static extern SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_ContextStartStep(SafeContextHandle ctx); | public static extern void TFE_ContextStartStep(SafeContextHandle ctx); | ||||
| @@ -254,7 +254,7 @@ namespace Tensorflow | |||||
| /// <param name="opts">TFE_ContextOptions*</param> | /// <param name="opts">TFE_ContextOptions*</param> | ||||
| /// <param name="enable">unsigned char</param> | /// <param name="enable">unsigned char</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_ContextOptionsSetAsync(IntPtr opts, byte enable); | |||||
| public static extern void TFE_ContextOptionsSetAsync(SafeContextOptionsHandle opts, byte enable); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| @@ -102,15 +102,12 @@ namespace TensorFlowNET.UnitTest | |||||
| protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) | protected void TFE_Execute(IntPtr op, IntPtr[] retvals, ref int num_retvals, SafeStatusHandle status) | ||||
| => c_api.TFE_Execute(op, retvals, ref num_retvals, status); | => c_api.TFE_Execute(op, retvals, ref num_retvals, status); | ||||
| protected IntPtr TFE_NewContextOptions() | |||||
| protected SafeContextOptionsHandle TFE_NewContextOptions() | |||||
| => c_api.TFE_NewContextOptions(); | => c_api.TFE_NewContextOptions(); | ||||
| protected SafeContextHandle TFE_NewContext(IntPtr opts, SafeStatusHandle status) | |||||
| protected SafeContextHandle TFE_NewContext(SafeContextOptionsHandle opts, SafeStatusHandle status) | |||||
| => c_api.TFE_NewContext(opts, status); | => c_api.TFE_NewContext(opts, status); | ||||
| protected void TFE_DeleteContextOptions(IntPtr opts) | |||||
| => c_api.TFE_DeleteContextOptions(opts); | |||||
| protected int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status) | protected int TFE_OpGetInputLength(IntPtr op, string input_name, SafeStatusHandle status) | ||||
| => c_api.TFE_OpGetInputLength(op, input_name, status); | => c_api.TFE_OpGetInputLength(op, input_name, status); | ||||
| @@ -1,6 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | |||||
| namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
| { | { | ||||
| @@ -13,13 +14,16 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| public void Context() | public void Context() | ||||
| { | { | ||||
| using var status = c_api.TF_NewStatus(); | using var status = c_api.TF_NewStatus(); | ||||
| var opts = c_api.TFE_NewContextOptions(); | |||||
| IntPtr devices; | |||||
| using (var ctx = c_api.TFE_NewContext(opts, status)) | |||||
| static SafeContextHandle NewContext(SafeStatusHandle status) | |||||
| { | { | ||||
| c_api.TFE_DeleteContextOptions(opts); | |||||
| using var opts = c_api.TFE_NewContextOptions(); | |||||
| return c_api.TFE_NewContext(opts, status); | |||||
| } | |||||
| IntPtr devices; | |||||
| using (var ctx = NewContext(status)) | |||||
| { | |||||
| devices = c_api.TFE_ContextListDevices(ctx, status); | devices = c_api.TFE_ContextListDevices(ctx, status); | ||||
| EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | EXPECT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
| @@ -19,14 +20,18 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| unsafe void Execute_MatMul_CPU(bool async) | unsafe void Execute_MatMul_CPU(bool async) | ||||
| { | { | ||||
| using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
| var opts = TFE_NewContextOptions(); | |||||
| c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); | |||||
| static SafeContextHandle NewContext(bool async, SafeStatusHandle status) | |||||
| { | |||||
| using var opts = c_api.TFE_NewContextOptions(); | |||||
| c_api.TFE_ContextOptionsSetAsync(opts, Convert.ToByte(async)); | |||||
| return c_api.TFE_NewContext(opts, status); | |||||
| } | |||||
| IntPtr t; | IntPtr t; | ||||
| using (var ctx = TFE_NewContext(opts, status)) | |||||
| using (var ctx = NewContext(async, status)) | |||||
| { | { | ||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| TFE_DeleteContextOptions(opts); | |||||
| var m = TestMatrixTensorHandle(); | var m = TestMatrixTensorHandle(); | ||||
| var matmul = MatMulOp(ctx, m, m); | var matmul = MatMulOp(ctx, m, m); | ||||
| @@ -2,7 +2,6 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Buffer = System.Buffer; | |||||
| namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
| { | { | ||||
| @@ -15,10 +14,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| public unsafe void OpGetInputAndOutputLengths() | public unsafe void OpGetInputAndOutputLengths() | ||||
| { | { | ||||
| using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
| var opts = TFE_NewContextOptions(); | |||||
| using var ctx = TFE_NewContext(opts, status); | |||||
| static SafeContextHandle NewContext(SafeStatusHandle status) | |||||
| { | |||||
| using var opts = c_api.TFE_NewContextOptions(); | |||||
| return c_api.TFE_NewContext(opts, status); | |||||
| } | |||||
| using var ctx = NewContext(status); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| TFE_DeleteContextOptions(opts); | |||||
| var input1 = TestMatrixTensorHandle(); | var input1 = TestMatrixTensorHandle(); | ||||
| var input2 = TestMatrixTensorHandle(); | var input2 = TestMatrixTensorHandle(); | ||||
| @@ -1,10 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Buffer = System.Buffer; | |||||
| using System.Linq; | |||||
| namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
| { | { | ||||
| @@ -17,10 +14,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| public unsafe void OpInferMixedTypeInputListAttrs() | public unsafe void OpInferMixedTypeInputListAttrs() | ||||
| { | { | ||||
| using var status = TF_NewStatus(); | using var status = TF_NewStatus(); | ||||
| var opts = TFE_NewContextOptions(); | |||||
| using var ctx = TFE_NewContext(opts, status); | |||||
| static SafeContextHandle NewContext(SafeStatusHandle status) | |||||
| { | |||||
| using var opts = c_api.TFE_NewContextOptions(); | |||||
| return c_api.TFE_NewContext(opts, status); | |||||
| } | |||||
| using var ctx = NewContext(status); | |||||
| CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | CHECK_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| TFE_DeleteContextOptions(opts); | |||||
| var condition = TestScalarTensorHandle(true); | var condition = TestScalarTensorHandle(true); | ||||
| var t1 = TestMatrixTensorHandle(); | var t1 = TestMatrixTensorHandle(); | ||||
| @@ -2,7 +2,6 @@ | |||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Buffer = System.Buffer; | |||||
| namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
| { | { | ||||
| @@ -15,9 +14,14 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| public unsafe void TensorHandleDevices() | public unsafe void TensorHandleDevices() | ||||
| { | { | ||||
| var status = c_api.TF_NewStatus(); | var status = c_api.TF_NewStatus(); | ||||
| var opts = TFE_NewContextOptions(); | |||||
| using var ctx = TFE_NewContext(opts, status); | |||||
| TFE_DeleteContextOptions(opts); | |||||
| static SafeContextHandle NewContext(SafeStatusHandle status) | |||||
| { | |||||
| using var opts = c_api.TFE_NewContextOptions(); | |||||
| return c_api.TFE_NewContext(opts, status); | |||||
| } | |||||
| using var ctx = NewContext(status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| var hcpu = TestMatrixTensorHandle(); | var hcpu = TestMatrixTensorHandle(); | ||||
| @@ -1,6 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace TensorFlowNET.UnitTest.NativeAPI | namespace TensorFlowNET.UnitTest.NativeAPI | ||||
| @@ -14,10 +15,15 @@ namespace TensorFlowNET.UnitTest.NativeAPI | |||||
| public unsafe void Variables() | public unsafe void Variables() | ||||
| { | { | ||||
| using var status = c_api.TF_NewStatus(); | using var status = c_api.TF_NewStatus(); | ||||
| var opts = TFE_NewContextOptions(); | |||||
| using var ctx = TFE_NewContext(opts, status); | |||||
| static SafeContextHandle NewContext(SafeStatusHandle status) | |||||
| { | |||||
| using var opts = c_api.TFE_NewContextOptions(); | |||||
| return c_api.TFE_NewContext(opts, status); | |||||
| } | |||||
| using var ctx = NewContext(status); | |||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||
| TFE_DeleteContextOptions(opts); | |||||
| var var_handle = CreateVariable(ctx, 12.0f, status); | var var_handle = CreateVariable(ctx, 12.0f, status); | ||||
| ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ASSERT_EQ(TF_OK, TF_GetCode(status), TF_Message(status)); | ||||