Browse Source

Use a local Status variable

Using a local reference ensure that the Status object cannot be disposed before the Dispose. This way it's also possible to use an external Status instance instead of the static one, if needed.
pull/978/head
Superpiffer GitHub 2 years ago
parent
commit
67a2fcd628
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 9 additions and 20 deletions
  1. +9
    -20
      src/TensorFlowNET.Core/Sessions/BaseSession.cs

+ 9
- 20
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -30,6 +30,7 @@ namespace Tensorflow
public class BaseSession : DisposableObject
{
protected Graph _graph;
protected Status _status;
public Graph graph => _graph;

public BaseSession(IntPtr handle, Graph g)
@@ -48,9 +49,9 @@ namespace Tensorflow
}
using var opts = new SessionOptions(target, config);
status = status ?? tf.Status;
_handle = c_api.TF_NewSession(_graph, opts.Handle, status.Handle);
status.Check(true);
_status = status ?? tf.Status;
_handle = c_api.TF_NewSession(_graph, opts.Handle, _status.Handle);
_status.Check(true);
}

public virtual void run(Operation op, params FeedItem[] feed_dict)
@@ -217,8 +218,6 @@ namespace Tensorflow
// Ensure any changes to the graph are reflected in the runtime.
_extend_graph();

var status = tf.Status;

var output_values = fetch_list.Select(x => IntPtr.Zero).ToArray();

c_api.TF_SessionRun(_handle,
@@ -232,9 +231,9 @@ namespace Tensorflow
target_opers: target_list.Select(f => (IntPtr)f).ToArray(),
ntargets: target_list.Count,
run_metadata: IntPtr.Zero,
status: status.Handle);
status: _status.Handle);

status.Check(true);
_status.Check(true);

var result = new NDArray[fetch_list.Length];

@@ -246,8 +245,6 @@ namespace Tensorflow

public unsafe Tensor eval(Tensor tensor)
{
var status = tf.Status;

var output_values = new IntPtr[1];
var fetch_list = new[] { tensor._as_tf_output() };

@@ -262,9 +259,9 @@ namespace Tensorflow
target_opers: new IntPtr[0],
ntargets: 0,
run_metadata: IntPtr.Zero,
status: status.Handle);
status: _status.Handle);

status.Check(true);
_status.Check(true);

return new Tensor(new SafeTensorHandle(output_values[0]));
}
@@ -291,15 +288,7 @@ namespace Tensorflow
protected override void DisposeUnmanagedResources(IntPtr handle)
{
// c_api.TF_CloseSession(handle, tf.Status.Handle);
if (tf.Status == null || tf.Status.Handle.IsInvalid)
{
using var status = new Status();
c_api.TF_DeleteSession(handle, status.Handle);
}
else
{
c_api.TF_DeleteSession(handle, tf.Status.Handle);
}
c_api.TF_DeleteSession(handle, _status.Handle);
}
}
}

Loading…
Cancel
Save