diff --git a/src/TensorFlowNET.Core/DisposableObject.cs b/src/TensorFlowNET.Core/DisposableObject.cs index 4f021360..3c70739b 100644 --- a/src/TensorFlowNET.Core/DisposableObject.cs +++ b/src/TensorFlowNET.Core/DisposableObject.cs @@ -29,14 +29,12 @@ namespace Tensorflow protected IntPtr _handle; protected bool _disposed; - [SuppressMessage("ReSharper", "UnusedMember.Global")] protected DisposableObject() { } protected DisposableObject(IntPtr handle) => _handle = handle; - [SuppressMessage("ReSharper", "InvertIf")] private void Dispose(bool disposing) { if (_disposed) diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index 3ab800f6..81ae271d 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -94,11 +94,5 @@ namespace Tensorflow.Eager // c_api.TF_GraphSetOutputHandleShapesAndTypes(target_t.graph, target_t._as_tf_output(), 0, new IntPtr[0], new int[0], new DataType[0], tf.Status.Handle); } } - - protected override void DisposeUnmanagedResources(IntPtr handle) - { - base.DisposeUnmanagedResources(handle); - _eagerTensorHandle.Dispose(); - } } } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs index 303d01d8..2f89bc5d 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs @@ -43,44 +43,30 @@ namespace Tensorflow.NumPy double val => new Tensor(val), _ => throw new NotImplementedException("") }; - _tensor.SetReferencedByNDArray(); - var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle); - _tensor.SetEagerTensorHandle(_handle); + _tensor.SetReferencedByNDArray(); } void Init(Array value, Shape? shape = null) { _tensor = new Tensor(value, shape ?? value.GetShape()); _tensor.SetReferencedByNDArray(); - - var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle); - _tensor.SetEagerTensorHandle(_handle); } void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) { _tensor = new Tensor(shape, dtype: dtype); _tensor.SetReferencedByNDArray(); - - var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle); - _tensor.SetEagerTensorHandle(_handle); } void Init(Tensor value, Shape? shape = null) { - if (shape is not null) - _tensor = new Tensor(value.TensorDataPointer, shape, value.dtype); - else - _tensor = value; - - if (_tensor.TensorDataPointer == IntPtr.Zero) - _tensor = tf.get_default_session().eval(_tensor); + // created tensor in graph mode + if (value.TensorDataPointer == IntPtr.Zero) + value = tf.defaultSession.eval(value); + _tensor = new Tensor(value.TensorDataPointer, shape ?? value.shape, value.dtype); _tensor.SetReferencedByNDArray(); - - var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle); - _tensor.SetEagerTensorHandle(_handle); } } } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs index f06b8366..1cfc9b4e 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -6,7 +6,7 @@ using static Tensorflow.Binding; namespace Tensorflow.NumPy { - public partial class NDArray : DisposableObject + public partial class NDArray { Tensor _tensor; public TF_DataType dtype => _tensor.dtype; @@ -58,11 +58,5 @@ namespace Tensorflow.NumPy { return tensor_util.to_numpy_string(_tensor); } - - protected override void DisposeUnmanagedResources(IntPtr handle) - { - _tensor.EagerTensorHandle.Dispose(); - _tensor.Dispose(); - } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index bf52d13c..6eebb523 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -212,9 +212,12 @@ namespace Tensorflow return _tf_output.Value; } - public void SetReferencedByNDArray() => isReferencedByNDArray = true; - public void SetEagerTensorHandle(SafeTensorHandleHandle handle) => _eagerTensorHandle = handle; - + public void SetReferencedByNDArray() + { + isReferencedByNDArray = true; + _eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); + } + public Tensor MaybeMove() { var tensor = c_api.TF_TensorMaybeMove(_handle); @@ -256,7 +259,6 @@ namespace Tensorflow } } - [SuppressMessage("ReSharper", "ConvertIfStatementToSwitchStatement")] protected override void DisposeUnmanagedResources(IntPtr handle) { if (dtype == TF_DataType.TF_STRING) @@ -274,6 +276,9 @@ namespace Tensorflow } c_api.TF_DeleteTensor(handle); + + if (_eagerTensorHandle is not null) + _eagerTensorHandle.Dispose(); } public bool IsDisposed => _disposed; diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 0f168904..d97ea1da 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -35,7 +35,9 @@ namespace Tensorflow /// public static NDArray constant_value(Tensor tensor, bool partial = false) { - if (tensor is EagerTensor) + if (tensor.IsReferencedByNDArray) + return new NDArray(tensor); + else if (tensor is EagerTensor) return tensor.numpy(); NDArray ret = _ConstantValue(tensor, partial); diff --git a/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs b/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs index bb3910b9..e5c46e29 100644 --- a/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs +++ b/test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs @@ -10,11 +10,5 @@ namespace TensorFlowNET.UnitTest { tf.compat.v1.disable_eager_execution(); } - - [TestCleanup] - public void TestClean() - { - tf.enable_eager_execution(); - } } }