diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
index 888c4fea..971e5ddf 100644
--- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
+++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs
@@ -287,7 +287,7 @@ namespace Tensorflow
/// TF_Status*
///
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_LoadSessionFromSavedModel(IntPtr session_options, IntPtr run_options,
+ public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options,
string export_dir, string[] tags, int tags_len,
IntPtr graph, ref TF_Buffer meta_graph_def, SafeStatusHandle status);
diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
index 8f8e3f9f..76d434b4 100644
--- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs
+++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
@@ -47,7 +47,7 @@ namespace Tensorflow
lock (Locks.ProcessWide)
{
status = status ?? new Status();
- _handle = c_api.TF_NewSession(_graph, opts, status.Handle);
+ _handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle);
status.Check(true);
}
}
diff --git a/src/TensorFlowNET.Core/Sessions/SafeSessionOptionsHandle.cs b/src/TensorFlowNET.Core/Sessions/SafeSessionOptionsHandle.cs
new file mode 100644
index 00000000..95109cce
--- /dev/null
+++ b/src/TensorFlowNET.Core/Sessions/SafeSessionOptionsHandle.cs
@@ -0,0 +1,40 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using Tensorflow.Util;
+
+namespace Tensorflow
+{
+ public sealed class SafeSessionOptionsHandle : SafeTensorflowHandle
+ {
+ public SafeSessionOptionsHandle()
+ {
+ }
+
+ public SafeSessionOptionsHandle(IntPtr handle)
+ : base(handle)
+ {
+ }
+
+ protected override bool ReleaseHandle()
+ {
+ c_api.TF_DeleteSessionOptions(handle);
+ SetHandle(IntPtr.Zero);
+ return true;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs
index bda3acbd..6de2c606 100644
--- a/src/TensorFlowNET.Core/Sessions/Session.cs
+++ b/src/TensorFlowNET.Core/Sessions/Session.cs
@@ -55,7 +55,7 @@ namespace Tensorflow
IntPtr sess;
try
{
- sess = c_api.TF_LoadSessionFromSavedModel(opt,
+ sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle,
IntPtr.Zero,
path,
tags,
@@ -66,7 +66,7 @@ namespace Tensorflow
status.Check(true);
} catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel"))
{
- sess = c_api.TF_LoadSessionFromSavedModel(opt,
+ sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle,
IntPtr.Zero,
Path.GetFullPath(path),
tags,
diff --git a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs
index 56f13628..00923d14 100644
--- a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs
+++ b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs
@@ -16,44 +16,36 @@
using Google.Protobuf;
using System;
-using System.Runtime.InteropServices;
namespace Tensorflow
{
- internal class SessionOptions : DisposableObject
+ internal sealed class SessionOptions : IDisposable
{
+ public SafeSessionOptionsHandle Handle { get; }
+
public SessionOptions(string target = "", ConfigProto config = null)
{
- _handle = c_api.TF_NewSessionOptions();
- c_api.TF_SetTarget(_handle, target);
+ Handle = c_api.TF_NewSessionOptions();
+ c_api.TF_SetTarget(Handle, target);
if (config != null)
SetConfig(config);
}
- public SessionOptions(IntPtr handle)
- {
- _handle = handle;
- }
-
- protected override void DisposeUnmanagedResources(IntPtr handle)
- => c_api.TF_DeleteSessionOptions(handle);
+ public void Dispose()
+ => Handle.Dispose();
- private void SetConfig(ConfigProto config)
+ private unsafe void SetConfig(ConfigProto config)
{
var bytes = config.ToByteArray();
- var proto = Marshal.AllocHGlobal(bytes.Length);
- Marshal.Copy(bytes, 0, proto, bytes.Length);
- using (var status = new Status())
+ fixed (byte* proto2 = bytes)
{
- c_api.TF_SetConfig(_handle, proto, (ulong)bytes.Length, status.Handle);
- status.Check(false);
+ using (var status = new Status())
+ {
+ c_api.TF_SetConfig(Handle, (IntPtr)proto2, (ulong)bytes.Length, status.Handle);
+ status.Check(false);
+ }
}
-
- Marshal.FreeHGlobal(proto);
}
-
- public static implicit operator IntPtr(SessionOptions opts) => opts._handle;
- public static implicit operator SessionOptions(IntPtr handle) => new SessionOptions(handle);
}
}
diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs
index 7082c617..8ac4d53e 100644
--- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs
+++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs
@@ -50,14 +50,14 @@ namespace Tensorflow
/// TF_Status*
/// TF_Session*
[DllImport(TensorFlowLibName)]
- public static extern IntPtr TF_NewSession(IntPtr graph, IntPtr opts, SafeStatusHandle status);
+ public static extern IntPtr TF_NewSession(IntPtr graph, SafeSessionOptionsHandle opts, SafeStatusHandle status);
///
/// Return a new options object.
///
/// TF_SessionOptions*
[DllImport(TensorFlowLibName)]
- public static extern unsafe IntPtr TF_NewSessionOptions();
+ public static extern SafeSessionOptionsHandle TF_NewSessionOptions();
///
/// Run the graph associated with the session starting with the supplied inputs
@@ -116,9 +116,9 @@ namespace Tensorflow
/// size_t
/// TF_Status*
[DllImport(TensorFlowLibName)]
- public static extern void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, SafeStatusHandle status);
+ public static extern void TF_SetConfig(SafeSessionOptionsHandle options, IntPtr proto, ulong proto_len, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
- public static extern void TF_SetTarget(IntPtr options, string target);
+ public static extern void TF_SetTarget(SafeSessionOptionsHandle options, string target);
}
}