| @@ -287,7 +287,7 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns></returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_LoadSessionFromSavedModel(IntPtr session_options, IntPtr run_options, | |||
| public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, | |||
| string export_dir, string[] tags, int tags_len, | |||
| IntPtr graph, ref TF_Buffer meta_graph_def, SafeStatusHandle status); | |||
| @@ -47,7 +47,7 @@ namespace Tensorflow | |||
| lock (Locks.ProcessWide) | |||
| { | |||
| status = status ?? new Status(); | |||
| _handle = c_api.TF_NewSession(_graph, opts, status.Handle); | |||
| _handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle); | |||
| status.Check(true); | |||
| } | |||
| } | |||
| @@ -0,0 +1,40 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow | |||
| { | |||
| public sealed class SafeSessionOptionsHandle : SafeTensorflowHandle | |||
| { | |||
| public SafeSessionOptionsHandle() | |||
| { | |||
| } | |||
| public SafeSessionOptionsHandle(IntPtr handle) | |||
| : base(handle) | |||
| { | |||
| } | |||
| protected override bool ReleaseHandle() | |||
| { | |||
| c_api.TF_DeleteSessionOptions(handle); | |||
| SetHandle(IntPtr.Zero); | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| @@ -55,7 +55,7 @@ namespace Tensorflow | |||
| IntPtr sess; | |||
| try | |||
| { | |||
| sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
| sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, | |||
| IntPtr.Zero, | |||
| path, | |||
| tags, | |||
| @@ -66,7 +66,7 @@ namespace Tensorflow | |||
| status.Check(true); | |||
| } catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) | |||
| { | |||
| sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||
| sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, | |||
| IntPtr.Zero, | |||
| Path.GetFullPath(path), | |||
| tags, | |||
| @@ -16,44 +16,36 @@ | |||
| using Google.Protobuf; | |||
| using System; | |||
| using System.Runtime.InteropServices; | |||
| namespace Tensorflow | |||
| { | |||
| internal class SessionOptions : DisposableObject | |||
| internal sealed class SessionOptions : IDisposable | |||
| { | |||
| public SafeSessionOptionsHandle Handle { get; } | |||
| public SessionOptions(string target = "", ConfigProto config = null) | |||
| { | |||
| _handle = c_api.TF_NewSessionOptions(); | |||
| c_api.TF_SetTarget(_handle, target); | |||
| Handle = c_api.TF_NewSessionOptions(); | |||
| c_api.TF_SetTarget(Handle, target); | |||
| if (config != null) | |||
| SetConfig(config); | |||
| } | |||
| public SessionOptions(IntPtr handle) | |||
| { | |||
| _handle = handle; | |||
| } | |||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||
| => c_api.TF_DeleteSessionOptions(handle); | |||
| public void Dispose() | |||
| => Handle.Dispose(); | |||
| private void SetConfig(ConfigProto config) | |||
| private unsafe void SetConfig(ConfigProto config) | |||
| { | |||
| var bytes = config.ToByteArray(); | |||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | |||
| using (var status = new Status()) | |||
| fixed (byte* proto2 = bytes) | |||
| { | |||
| c_api.TF_SetConfig(_handle, proto, (ulong)bytes.Length, status.Handle); | |||
| status.Check(false); | |||
| using (var status = new Status()) | |||
| { | |||
| c_api.TF_SetConfig(Handle, (IntPtr)proto2, (ulong)bytes.Length, status.Handle); | |||
| status.Check(false); | |||
| } | |||
| } | |||
| Marshal.FreeHGlobal(proto); | |||
| } | |||
| public static implicit operator IntPtr(SessionOptions opts) => opts._handle; | |||
| public static implicit operator SessionOptions(IntPtr handle) => new SessionOptions(handle); | |||
| } | |||
| } | |||
| @@ -50,14 +50,14 @@ namespace Tensorflow | |||
| /// <param name="status">TF_Status*</param> | |||
| /// <returns>TF_Session*</returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern IntPtr TF_NewSession(IntPtr graph, IntPtr opts, SafeStatusHandle status); | |||
| public static extern IntPtr TF_NewSession(IntPtr graph, SafeSessionOptionsHandle opts, SafeStatusHandle status); | |||
| /// <summary> | |||
| /// Return a new options object. | |||
| /// </summary> | |||
| /// <returns>TF_SessionOptions*</returns> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern unsafe IntPtr TF_NewSessionOptions(); | |||
| public static extern SafeSessionOptionsHandle TF_NewSessionOptions(); | |||
| /// <summary> | |||
| /// Run the graph associated with the session starting with the supplied inputs | |||
| @@ -116,9 +116,9 @@ namespace Tensorflow | |||
| /// <param name="proto_len">size_t</param> | |||
| /// <param name="status">TF_Status*</param> | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, SafeStatusHandle status); | |||
| public static extern void TF_SetConfig(SafeSessionOptionsHandle options, IntPtr proto, ulong proto_len, SafeStatusHandle status); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern void TF_SetTarget(IntPtr options, string target); | |||
| public static extern void TF_SetTarget(SafeSessionOptionsHandle options, string target); | |||
| } | |||
| } | |||