From 6931d5c42592ee86b610dab1d12fddd2560b8697 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Mon, 21 Oct 2019 06:03:22 -0500 Subject: [PATCH] change Session(ConfigProto). --- .../Sessions/BaseSession.cs | 15 ++++++----- src/TensorFlowNET.Core/Sessions/Session.cs | 2 +- .../Sessions/SessionOptions.cs | 13 ++++++---- .../Sessions/TF_DeprecatedSession.cs | 26 ------------------- .../Sessions/TF_SessionOptions.cs | 10 ------- .../Sessions/c_api.session.cs | 5 +++- 6 files changed, 21 insertions(+), 50 deletions(-) delete mode 100644 src/TensorFlowNET.Core/Sessions/TF_DeprecatedSession.cs delete mode 100644 src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 1701c625..bb37956c 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -36,19 +36,20 @@ namespace Tensorflow protected byte[] _target; public Graph graph => _graph; - public BaseSession(string target = "", Graph g = null, SessionOptions opts = null, Status status = null) + public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null) { _graph = g ?? ops.get_default_graph(); _graph.as_default(); _target = Encoding.UTF8.GetBytes(target); - SessionOptions lopts = opts ?? new SessionOptions(); - - lock (Locks.ProcessWide) + using (var opts = new SessionOptions(target, config)) { - status = status ?? new Status(); - _handle = c_api.TF_NewSession(_graph, opts ?? lopts, status); - status.Check(true); + lock (Locks.ProcessWide) + { + status = status ?? new Status(); + _handle = c_api.TF_NewSession(_graph, opts, status); + status.Check(true); + } } } diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index a89d94dc..caa669d3 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -32,7 +32,7 @@ namespace Tensorflow _handle = handle; } - public Session(Graph g, SessionOptions opts = null, Status s = null) : base("", g, opts, s) + public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s) { } public Session as_default() diff --git a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs index 112543fe..0e64033c 100644 --- a/src/TensorFlowNET.Core/Sessions/SessionOptions.cs +++ b/src/TensorFlowNET.Core/Sessions/SessionOptions.cs @@ -20,11 +20,14 @@ using System.Runtime.InteropServices; namespace Tensorflow { - public class SessionOptions : DisposableObject + internal class SessionOptions : DisposableObject { - public SessionOptions() + public SessionOptions(string target = "", ConfigProto config = null) { _handle = c_api.TF_NewSessionOptions(); + c_api.TF_SetTarget(_handle, target); + if (config != null) + SetConfig(config); } public SessionOptions(IntPtr handle) @@ -35,10 +38,10 @@ namespace Tensorflow protected override void DisposeUnmanagedResources(IntPtr handle) => c_api.TF_DeleteSessionOptions(handle); - public void SetConfig(ConfigProto config) + private void SetConfig(ConfigProto config) { - var bytes = config.ToByteArray(); //TODO! we can use WriteTo - var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak + var bytes = config.ToByteArray(); + var proto = Marshal.AllocHGlobal(bytes.Length); Marshal.Copy(bytes, 0, proto, bytes.Length); using (var status = new Status()) diff --git a/src/TensorFlowNET.Core/Sessions/TF_DeprecatedSession.cs b/src/TensorFlowNET.Core/Sessions/TF_DeprecatedSession.cs deleted file mode 100644 index baab71a0..00000000 --- a/src/TensorFlowNET.Core/Sessions/TF_DeprecatedSession.cs +++ /dev/null @@ -1,26 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using System.Runtime.InteropServices; - -namespace Tensorflow.Sessions -{ - [StructLayout(LayoutKind.Sequential)] - public struct TF_DeprecatedSession - { - Session session; - } -} diff --git a/src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs b/src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs deleted file mode 100644 index 35b27f12..00000000 --- a/src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs +++ /dev/null @@ -1,10 +0,0 @@ -using System.Runtime.InteropServices; - -namespace Tensorflow -{ - [StructLayout(LayoutKind.Sequential)] - public struct TF_SessionOptions - { - public SessionOptions options; - } -} diff --git a/src/TensorFlowNET.Core/Sessions/c_api.session.cs b/src/TensorFlowNET.Core/Sessions/c_api.session.cs index 3ed60435..713d0d5f 100644 --- a/src/TensorFlowNET.Core/Sessions/c_api.session.cs +++ b/src/TensorFlowNET.Core/Sessions/c_api.session.cs @@ -116,6 +116,9 @@ namespace Tensorflow /// size_t /// TF_Status* [DllImport(TensorFlowLibName)] - public static extern unsafe void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, IntPtr status); + public static extern void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, IntPtr status); + + [DllImport(TensorFlowLibName)] + public static extern void TF_SetTarget(IntPtr options, string target); } }