| @@ -287,7 +287,7 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_LoadSessionFromSavedModel(IntPtr session_options, IntPtr run_options, | |||||
| public static extern IntPtr TF_LoadSessionFromSavedModel(SafeSessionOptionsHandle session_options, IntPtr run_options, | |||||
| string export_dir, string[] tags, int tags_len, | string export_dir, string[] tags, int tags_len, | ||||
| IntPtr graph, ref TF_Buffer meta_graph_def, SafeStatusHandle status); | IntPtr graph, ref TF_Buffer meta_graph_def, SafeStatusHandle status); | ||||
| @@ -47,7 +47,7 @@ namespace Tensorflow | |||||
| lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
| { | { | ||||
| status = status ?? new Status(); | status = status ?? new Status(); | ||||
| _handle = c_api.TF_NewSession(_graph, opts, status.Handle); | |||||
| _handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle); | |||||
| status.Check(true); | status.Check(true); | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,40 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using Tensorflow.Util; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public sealed class SafeSessionOptionsHandle : SafeTensorflowHandle | |||||
| { | |||||
| public SafeSessionOptionsHandle() | |||||
| { | |||||
| } | |||||
| public SafeSessionOptionsHandle(IntPtr handle) | |||||
| : base(handle) | |||||
| { | |||||
| } | |||||
| protected override bool ReleaseHandle() | |||||
| { | |||||
| c_api.TF_DeleteSessionOptions(handle); | |||||
| SetHandle(IntPtr.Zero); | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -55,7 +55,7 @@ namespace Tensorflow | |||||
| IntPtr sess; | IntPtr sess; | ||||
| try | try | ||||
| { | { | ||||
| sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||||
| sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, | |||||
| IntPtr.Zero, | IntPtr.Zero, | ||||
| path, | path, | ||||
| tags, | tags, | ||||
| @@ -66,7 +66,7 @@ namespace Tensorflow | |||||
| status.Check(true); | status.Check(true); | ||||
| } catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) | } catch (TensorflowException ex) when (ex.Message.Contains("Could not find SavedModel")) | ||||
| { | { | ||||
| sess = c_api.TF_LoadSessionFromSavedModel(opt, | |||||
| sess = c_api.TF_LoadSessionFromSavedModel(opt.Handle, | |||||
| IntPtr.Zero, | IntPtr.Zero, | ||||
| Path.GetFullPath(path), | Path.GetFullPath(path), | ||||
| tags, | tags, | ||||
| @@ -16,44 +16,36 @@ | |||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using System; | using System; | ||||
| using System.Runtime.InteropServices; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| internal class SessionOptions : DisposableObject | |||||
| internal sealed class SessionOptions : IDisposable | |||||
| { | { | ||||
| public SafeSessionOptionsHandle Handle { get; } | |||||
| public SessionOptions(string target = "", ConfigProto config = null) | public SessionOptions(string target = "", ConfigProto config = null) | ||||
| { | { | ||||
| _handle = c_api.TF_NewSessionOptions(); | |||||
| c_api.TF_SetTarget(_handle, target); | |||||
| Handle = c_api.TF_NewSessionOptions(); | |||||
| c_api.TF_SetTarget(Handle, target); | |||||
| if (config != null) | if (config != null) | ||||
| SetConfig(config); | SetConfig(config); | ||||
| } | } | ||||
| public SessionOptions(IntPtr handle) | |||||
| { | |||||
| _handle = handle; | |||||
| } | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| => c_api.TF_DeleteSessionOptions(handle); | |||||
| public void Dispose() | |||||
| => Handle.Dispose(); | |||||
| private void SetConfig(ConfigProto config) | |||||
| private unsafe void SetConfig(ConfigProto config) | |||||
| { | { | ||||
| var bytes = config.ToByteArray(); | var bytes = config.ToByteArray(); | ||||
| var proto = Marshal.AllocHGlobal(bytes.Length); | |||||
| Marshal.Copy(bytes, 0, proto, bytes.Length); | |||||
| using (var status = new Status()) | |||||
| fixed (byte* proto2 = bytes) | |||||
| { | { | ||||
| c_api.TF_SetConfig(_handle, proto, (ulong)bytes.Length, status.Handle); | |||||
| status.Check(false); | |||||
| using (var status = new Status()) | |||||
| { | |||||
| c_api.TF_SetConfig(Handle, (IntPtr)proto2, (ulong)bytes.Length, status.Handle); | |||||
| status.Check(false); | |||||
| } | |||||
| } | } | ||||
| Marshal.FreeHGlobal(proto); | |||||
| } | } | ||||
| public static implicit operator IntPtr(SessionOptions opts) => opts._handle; | |||||
| public static implicit operator SessionOptions(IntPtr handle) => new SessionOptions(handle); | |||||
| } | } | ||||
| } | } | ||||
| @@ -50,14 +50,14 @@ namespace Tensorflow | |||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| /// <returns>TF_Session*</returns> | /// <returns>TF_Session*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern IntPtr TF_NewSession(IntPtr graph, IntPtr opts, SafeStatusHandle status); | |||||
| public static extern IntPtr TF_NewSession(IntPtr graph, SafeSessionOptionsHandle opts, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// Return a new options object. | /// Return a new options object. | ||||
| /// </summary> | /// </summary> | ||||
| /// <returns>TF_SessionOptions*</returns> | /// <returns>TF_SessionOptions*</returns> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern unsafe IntPtr TF_NewSessionOptions(); | |||||
| public static extern SafeSessionOptionsHandle TF_NewSessionOptions(); | |||||
| /// <summary> | /// <summary> | ||||
| /// Run the graph associated with the session starting with the supplied inputs | /// Run the graph associated with the session starting with the supplied inputs | ||||
| @@ -116,9 +116,9 @@ namespace Tensorflow | |||||
| /// <param name="proto_len">size_t</param> | /// <param name="proto_len">size_t</param> | ||||
| /// <param name="status">TF_Status*</param> | /// <param name="status">TF_Status*</param> | ||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, SafeStatusHandle status); | |||||
| public static extern void TF_SetConfig(SafeSessionOptionsHandle options, IntPtr proto, ulong proto_len, SafeStatusHandle status); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TF_SetTarget(IntPtr options, string target); | |||||
| public static extern void TF_SetTarget(SafeSessionOptionsHandle options, string target); | |||||
| } | } | ||||
| } | } | ||||