Browse Source

change Session(ConfigProto).

tags/v0.12
Oceania2018 6 years ago
parent
commit
6931d5c425
6 changed files with 21 additions and 50 deletions
  1. +8
    -7
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Sessions/Session.cs
  3. +8
    -5
      src/TensorFlowNET.Core/Sessions/SessionOptions.cs
  4. +0
    -26
      src/TensorFlowNET.Core/Sessions/TF_DeprecatedSession.cs
  5. +0
    -10
      src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs
  6. +4
    -1
      src/TensorFlowNET.Core/Sessions/c_api.session.cs

+ 8
- 7
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -36,19 +36,20 @@ namespace Tensorflow
protected byte[] _target;
public Graph graph => _graph;

public BaseSession(string target = "", Graph g = null, SessionOptions opts = null, Status status = null)
public BaseSession(string target = "", Graph g = null, ConfigProto config = null, Status status = null)
{
_graph = g ?? ops.get_default_graph();
_graph.as_default();
_target = Encoding.UTF8.GetBytes(target);

SessionOptions lopts = opts ?? new SessionOptions();

lock (Locks.ProcessWide)
using (var opts = new SessionOptions(target, config))
{
status = status ?? new Status();
_handle = c_api.TF_NewSession(_graph, opts ?? lopts, status);
status.Check(true);
lock (Locks.ProcessWide)
{
status = status ?? new Status();
_handle = c_api.TF_NewSession(_graph, opts, status);
status.Check(true);
}
}
}



+ 1
- 1
src/TensorFlowNET.Core/Sessions/Session.cs View File

@@ -32,7 +32,7 @@ namespace Tensorflow
_handle = handle;
}

public Session(Graph g, SessionOptions opts = null, Status s = null) : base("", g, opts, s)
public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g, config, s)
{ }

public Session as_default()


+ 8
- 5
src/TensorFlowNET.Core/Sessions/SessionOptions.cs View File

@@ -20,11 +20,14 @@ using System.Runtime.InteropServices;

namespace Tensorflow
{
public class SessionOptions : DisposableObject
internal class SessionOptions : DisposableObject
{
public SessionOptions()
public SessionOptions(string target = "", ConfigProto config = null)
{
_handle = c_api.TF_NewSessionOptions();
c_api.TF_SetTarget(_handle, target);
if (config != null)
SetConfig(config);
}

public SessionOptions(IntPtr handle)
@@ -35,10 +38,10 @@ namespace Tensorflow
protected override void DisposeUnmanagedResources(IntPtr handle)
=> c_api.TF_DeleteSessionOptions(handle);

public void SetConfig(ConfigProto config)
private void SetConfig(ConfigProto config)
{
var bytes = config.ToByteArray(); //TODO! we can use WriteTo
var proto = Marshal.AllocHGlobal(bytes.Length); //TODO! potential memory leak
var bytes = config.ToByteArray();
var proto = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, proto, bytes.Length);

using (var status = new Status())


+ 0
- 26
src/TensorFlowNET.Core/Sessions/TF_DeprecatedSession.cs View File

@@ -1,26 +0,0 @@
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
******************************************************************************/

using System.Runtime.InteropServices;

namespace Tensorflow.Sessions
{
[StructLayout(LayoutKind.Sequential)]
public struct TF_DeprecatedSession
{
Session session;
}
}

+ 0
- 10
src/TensorFlowNET.Core/Sessions/TF_SessionOptions.cs View File

@@ -1,10 +0,0 @@
using System.Runtime.InteropServices;

namespace Tensorflow
{
[StructLayout(LayoutKind.Sequential)]
public struct TF_SessionOptions
{
public SessionOptions options;
}
}

+ 4
- 1
src/TensorFlowNET.Core/Sessions/c_api.session.cs View File

@@ -116,6 +116,9 @@ namespace Tensorflow
/// <param name="proto_len">size_t</param>
/// <param name="status">TF_Status*</param>
[DllImport(TensorFlowLibName)]
public static extern unsafe void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, IntPtr status);
public static extern void TF_SetConfig(IntPtr options, IntPtr proto, ulong proto_len, IntPtr status);

[DllImport(TensorFlowLibName)]
public static extern void TF_SetTarget(IntPtr options, string target);
}
}

Loading…
Cancel
Save