diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 888c4fea..971e5ddf 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -287,7 +287,7 @@ namespace Tensorflow /// TF_Status* /// [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_LoadSessionFromSavedModel(IntPtr session_options, IntPtr run_options, + public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, string export_dir, string[] tags, int tags_len, IntPtr graph, ref TF_Buffer meta_graph_def, SafeStatusHandle status); diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 8f8e3f9f..76d434b4 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -47,7 +47,7 @@ namespace Tensorflow lock (Locks.ProcessWide) { status = status ?? new Status(); - _handle = c_api.TF_NewSession(_graph, opts, status.Handle); + _handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle); status.Check(true); } } diff --git a/src/TensorFlowNET.Core/Sessions/SafeSessionOptionsHandle.cs b/src/TensorFlowNET.Core/Sessions/SafeSessionOptionsHandle.cs new file mode 100644 index 00000000..95109cce --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/SafeSessionOptionsHandle.cs @@ -0,0 +1,40 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using Tensorflow.Util; + +namespace Tensorflow +{ + public sealed class SafeSessionOptionsHandle : SafeTensorflowHandle + { + public SafeSessionOptionsHandle() + { + } + + public SafeSessionOptionsHandle(IntPtr handle) + : base(handle) + { + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteSessionOptions(handle); + SetHandle(IntPtr.Zero); + return true; + } + } +} diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index bda3acbd..6de2c606 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -55,7 +55,7 @@ namespace Tensorflow IntPtr sess; try { - sess = c_api.TF_LoadSessionFromSavedModel(opt, + sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, IntPtr.Zero, path, tags, @@ -66,7 +66,7 @@ namespace Tensorflow status.Check(true); } catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) { - sess = c_api.TF_LoadSessionFromSavedModel(opt, + sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, IntPtr.Zero, Path.GetFullPath(path), tags, diff --git a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs index 56f13628..00923d14 100644 --- a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs +++ b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs @@ -16,44 +16,36 @@ using Google.Protobuf; using System; -using System.Runtime.InteropServices; namespace Tensorflow { - internal class SessionOptions : DisposableObject + internal sealed class SessionOptions : IDisposable { + public SafeSessionOptionsHandle Handle { get; } + public SessionOptions(string target = "", ConfigProto config = null) { - _handle = c_api.TF_NewSessionOptions(); - c_api.TF_SetTarget(_handle, target); + Handle = c_api.TF_NewSessionOptions(); + c_api.TF_SetTarget(Handle, target); if (config != null) SetConfig(config); } - public SessionOptions(IntPtr handle) - { - _handle = handle; - } - - protected override void DisposeUnmanagedResources(IntPtr handle) - => c_api.TF_DeleteSessionOptions(handle); + public void Dispose() + => Handle.Dispose(); - private void SetConfig(ConfigProto config) + private unsafe void SetConfig(ConfigProto config) { var bytes = config.ToByteArray(); - var proto = Marshal.AllocHGlobal(bytes.Length); - Marshal.Copy(bytes, 0, proto, bytes.Length); - using (var status = new Status()) + fixed (byte* proto2 = bytes) { - c_api.TF_SetConfig(_handle, proto, (ulong)bytes.Length, status.Handle); - status.Check(false); + using (var status = new Status()) + { + c_api.TF_SetConfig(Handle, (IntPtr)proto2, (ulong)bytes.Length, status.Handle); + status.Check(false); + } } - - Marshal.FreeHGlobal(proto); } - - public static implicit operator IntPtr(SessionOptions opts) => opts._handle; - public static implicit operator SessionOptions(IntPtr handle) => new SessionOptions(handle); } } diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs index 7082c617..8ac4d53e 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs @@ -50,14 +50,14 @@ namespace Tensorflow /// TF_Status* /// TF_Session* [DllImport(TensorFlowLibName)] - public static extern IntPtr TF_NewSession(IntPtr graph, IntPtr opts, SafeStatusHandle status); + public static extern IntPtr TF_NewSession(IntPtr graph, SafeSessionOptionsHandle opts, SafeStatusHandle status); /// /// Return a new options object. /// /// TF_SessionOptions* [DllImport(TensorFlowLibName)] - public static extern unsafe IntPtr TF_NewSessionOptions(); + public static extern SafeSessionOptionsHandle TF_NewSessionOptions(); /// /// Run the graph associated with the session starting with the supplied inputs @@ -116,9 +116,9 @@ namespace Tensorflow /// size_t /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, SafeStatusHandle status); + public static extern void TF_SetConfig(SafeSessionOptionsHandle options, IntPtr proto, ulong proto_len, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - public static extern void TF_SetTarget(IntPtr options, string target); + public static extern void TF_SetTarget(SafeSessionOptionsHandle options, string target); } }