| @@ -29,14 +29,12 @@ namespace Tensorflow | |||||
| protected IntPtr _handle; | protected IntPtr _handle; | ||||
| protected bool _disposed; | protected bool _disposed; | ||||
| [SuppressMessage("ReSharper", "UnusedMember.Global")] | |||||
| protected DisposableObject() | protected DisposableObject() | ||||
| { } | { } | ||||
| protected DisposableObject(IntPtr handle) | protected DisposableObject(IntPtr handle) | ||||
| => _handle = handle; | => _handle = handle; | ||||
| [SuppressMessage("ReSharper", "InvertIf")] | |||||
| private void Dispose(bool disposing) | private void Dispose(bool disposing) | ||||
| { | { | ||||
| if (_disposed) | if (_disposed) | ||||
| @@ -94,11 +94,5 @@ namespace Tensorflow.Eager | |||||
| // c_api.TF_GraphSetOutputHandleShapesAndTypes(target_t.graph, target_t._as_tf_output(), 0, new IntPtr[0], new int[0], new DataType[0], tf.Status.Handle); | // c_api.TF_GraphSetOutputHandleShapesAndTypes(target_t.graph, target_t._as_tf_output(), 0, new IntPtr[0], new int[0], new DataType[0], tf.Status.Handle); | ||||
| } | } | ||||
| } | } | ||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| { | |||||
| base.DisposeUnmanagedResources(handle); | |||||
| _eagerTensorHandle.Dispose(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -43,44 +43,30 @@ namespace Tensorflow.NumPy | |||||
| double val => new Tensor(val), | double val => new Tensor(val), | ||||
| _ => throw new NotImplementedException("") | _ => throw new NotImplementedException("") | ||||
| }; | }; | ||||
| _tensor.SetReferencedByNDArray(); | |||||
| var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle); | |||||
| _tensor.SetEagerTensorHandle(_handle); | |||||
| _tensor.SetReferencedByNDArray(); | |||||
| } | } | ||||
| void Init(Array value, Shape? shape = null) | void Init(Array value, Shape? shape = null) | ||||
| { | { | ||||
| _tensor = new Tensor(value, shape ?? value.GetShape()); | _tensor = new Tensor(value, shape ?? value.GetShape()); | ||||
| _tensor.SetReferencedByNDArray(); | _tensor.SetReferencedByNDArray(); | ||||
| var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle); | |||||
| _tensor.SetEagerTensorHandle(_handle); | |||||
| } | } | ||||
| void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) | ||||
| { | { | ||||
| _tensor = new Tensor(shape, dtype: dtype); | _tensor = new Tensor(shape, dtype: dtype); | ||||
| _tensor.SetReferencedByNDArray(); | _tensor.SetReferencedByNDArray(); | ||||
| var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle); | |||||
| _tensor.SetEagerTensorHandle(_handle); | |||||
| } | } | ||||
| void Init(Tensor value, Shape? shape = null) | void Init(Tensor value, Shape? shape = null) | ||||
| { | { | ||||
| if (shape is not null) | |||||
| _tensor = new Tensor(value.TensorDataPointer, shape, value.dtype); | |||||
| else | |||||
| _tensor = value; | |||||
| if (_tensor.TensorDataPointer == IntPtr.Zero) | |||||
| _tensor = tf.get_default_session().eval(_tensor); | |||||
| // created tensor in graph mode | |||||
| if (value.TensorDataPointer == IntPtr.Zero) | |||||
| value = tf.defaultSession.eval(value); | |||||
| _tensor = new Tensor(value.TensorDataPointer, shape ?? value.shape, value.dtype); | |||||
| _tensor.SetReferencedByNDArray(); | _tensor.SetReferencedByNDArray(); | ||||
| var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle); | |||||
| _tensor.SetEagerTensorHandle(_handle); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -6,7 +6,7 @@ using static Tensorflow.Binding; | |||||
| namespace Tensorflow.NumPy | namespace Tensorflow.NumPy | ||||
| { | { | ||||
| public partial class NDArray : DisposableObject | |||||
| public partial class NDArray | |||||
| { | { | ||||
| Tensor _tensor; | Tensor _tensor; | ||||
| public TF_DataType dtype => _tensor.dtype; | public TF_DataType dtype => _tensor.dtype; | ||||
| @@ -58,11 +58,5 @@ namespace Tensorflow.NumPy | |||||
| { | { | ||||
| return tensor_util.to_numpy_string(_tensor); | return tensor_util.to_numpy_string(_tensor); | ||||
| } | } | ||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | |||||
| { | |||||
| _tensor.EagerTensorHandle.Dispose(); | |||||
| _tensor.Dispose(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -212,9 +212,12 @@ namespace Tensorflow | |||||
| return _tf_output.Value; | return _tf_output.Value; | ||||
| } | } | ||||
| public void SetReferencedByNDArray() => isReferencedByNDArray = true; | |||||
| public void SetEagerTensorHandle(SafeTensorHandleHandle handle) => _eagerTensorHandle = handle; | |||||
| public void SetReferencedByNDArray() | |||||
| { | |||||
| isReferencedByNDArray = true; | |||||
| _eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle); | |||||
| } | |||||
| public Tensor MaybeMove() | public Tensor MaybeMove() | ||||
| { | { | ||||
| var tensor = c_api.TF_TensorMaybeMove(_handle); | var tensor = c_api.TF_TensorMaybeMove(_handle); | ||||
| @@ -256,7 +259,6 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| [SuppressMessage("ReSharper", "ConvertIfStatementToSwitchStatement")] | |||||
| protected override void DisposeUnmanagedResources(IntPtr handle) | protected override void DisposeUnmanagedResources(IntPtr handle) | ||||
| { | { | ||||
| if (dtype == TF_DataType.TF_STRING) | if (dtype == TF_DataType.TF_STRING) | ||||
| @@ -274,6 +276,9 @@ namespace Tensorflow | |||||
| } | } | ||||
| c_api.TF_DeleteTensor(handle); | c_api.TF_DeleteTensor(handle); | ||||
| if (_eagerTensorHandle is not null) | |||||
| _eagerTensorHandle.Dispose(); | |||||
| } | } | ||||
| public bool IsDisposed => _disposed; | public bool IsDisposed => _disposed; | ||||
| @@ -35,7 +35,9 @@ namespace Tensorflow | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static NDArray constant_value(Tensor tensor, bool partial = false) | public static NDArray constant_value(Tensor tensor, bool partial = false) | ||||
| { | { | ||||
| if (tensor is EagerTensor) | |||||
| if (tensor.IsReferencedByNDArray) | |||||
| return new NDArray(tensor); | |||||
| else if (tensor is EagerTensor) | |||||
| return tensor.numpy(); | return tensor.numpy(); | ||||
| NDArray ret = _ConstantValue(tensor, partial); | NDArray ret = _ConstantValue(tensor, partial); | ||||
| @@ -10,11 +10,5 @@ namespace TensorFlowNET.UnitTest | |||||
| { | { | ||||
| tf.compat.v1.disable_eager_execution(); | tf.compat.v1.disable_eager_execution(); | ||||
| } | } | ||||
| [TestCleanup] | |||||
| public void TestClean() | |||||
| { | |||||
| tf.enable_eager_execution(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||