diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs index a512fba9..3ab800f6 100644 --- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs +++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs @@ -10,7 +10,7 @@ namespace Tensorflow.Eager public EagerTensor(SafeTensorHandleHandle handle) { _id = ops.uid(); - EagerTensorHandle = handle; + _eagerTensorHandle = handle; Resolve(); } @@ -59,20 +59,14 @@ namespace Tensorflow.Eager void NewEagerTensorHandle(IntPtr h) { _id = ops.uid(); - EagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle); + _eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle); tf.Status.Check(true); -#if TRACK_TENSOR_LIFE - print($"New EagerTensorHandle {EagerTensorHandle} {Id} From 0x{h.ToString("x16")}"); -#endif } private void Resolve() { - _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.Status.Handle); + _handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle); tf.Status.Check(true); -#if TRACK_TENSOR_LIFE - print($"Take EagerTensorHandle {EagerTensorHandle} {Id} Resolving 0x{_handle.ToString("x16")}"); -#endif } /// @@ -104,7 +98,7 @@ namespace Tensorflow.Eager protected override void DisposeUnmanagedResources(IntPtr handle) { base.DisposeUnmanagedResources(handle); - EagerTensorHandle.Dispose(); + _eagerTensorHandle.Dispose(); } } } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs index 12213857..303d01d8 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs @@ -25,6 +25,7 @@ namespace Tensorflow.NumPy bool val => new NDArray(val), byte val => new NDArray(val), int val => new NDArray(val), + long val => new NDArray(val), float val => new NDArray(val), double val => new NDArray(val), _ => throw new NotImplementedException("") @@ -32,26 +33,44 @@ namespace Tensorflow.NumPy void Init(T value) where T : unmanaged { - _tensor = new EagerTensor(value); + _tensor = value switch + { + bool val => new Tensor(val), + byte val => new Tensor(val), + int val => new Tensor(val), + long val => new Tensor(val), + float val => new Tensor(val), + double val => new Tensor(val), + _ => throw new NotImplementedException("") + }; _tensor.SetReferencedByNDArray(); + + var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle); + _tensor.SetEagerTensorHandle(_handle); } void Init(Array value, Shape? shape = null) { - _tensor = new EagerTensor(value, shape ?? value.GetShape()); + _tensor = new Tensor(value, shape ?? value.GetShape()); _tensor.SetReferencedByNDArray(); + + var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle); + _tensor.SetEagerTensorHandle(_handle); } void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) { - _tensor = new EagerTensor(shape, dtype: dtype); + _tensor = new Tensor(shape, dtype: dtype); _tensor.SetReferencedByNDArray(); + + var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle); + _tensor.SetEagerTensorHandle(_handle); } void Init(Tensor value, Shape? shape = null) { if (shape is not null) - _tensor = tf.reshape(value, shape); + _tensor = new Tensor(value.TensorDataPointer, shape, value.dtype); else _tensor = value; @@ -59,6 +78,9 @@ namespace Tensorflow.NumPy _tensor = tf.get_default_session().eval(_tensor); _tensor.SetReferencedByNDArray(); + + var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle); + _tensor.SetEagerTensorHandle(_handle); } } } diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs index 1cfc9b4e..f06b8366 100644 --- a/src/TensorFlowNET.Core/Numpy/NDArray.cs +++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs @@ -6,7 +6,7 @@ using static Tensorflow.Binding; namespace Tensorflow.NumPy { - public partial class NDArray + public partial class NDArray : DisposableObject { Tensor _tensor; public TF_DataType dtype => _tensor.dtype; @@ -58,5 +58,11 @@ namespace Tensorflow.NumPy { return tensor_util.to_numpy_string(_tensor); } + + protected override void DisposeUnmanagedResources(IntPtr handle) + { + _tensor.EagerTensorHandle.Dispose(); + _tensor.Dispose(); + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 49847856..1f839ee7 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -53,10 +53,9 @@ namespace Tensorflow /// Pointer to unmanaged, fixed or pinned memory which the caller owns /// Tensor shape /// TF data type - /// Size of the tensor in memory - public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes) + public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype) { - _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (ulong)num_bytes); + _handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer()); isCreatedInGraphMode = !tf.executing_eagerly(); } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index f0dd4274..bf52d13c 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -89,10 +89,11 @@ namespace Tensorflow /// public object Tag { get; set; } + protected SafeTensorHandleHandle _eagerTensorHandle; /// /// TFE_TensorHandle /// - public SafeTensorHandleHandle EagerTensorHandle { get; set; } + public SafeTensorHandleHandle EagerTensorHandle => _eagerTensorHandle; protected bool isReferencedByNDArray; public bool IsReferencedByNDArray => isReferencedByNDArray; @@ -212,6 +213,7 @@ namespace Tensorflow } public void SetReferencedByNDArray() => isReferencedByNDArray = true; + public void SetEagerTensorHandle(SafeTensorHandleHandle handle) => _eagerTensorHandle = handle; public Tensor MaybeMove() { @@ -254,30 +256,16 @@ namespace Tensorflow } } - /// - /// Dispose any managed resources. - /// - /// Equivalent to what you would perform inside - protected override void DisposeManagedResources() - { - - } - [SuppressMessage("ReSharper", "ConvertIfStatementToSwitchStatement")] protected override void DisposeUnmanagedResources(IntPtr handle) { -#if TRACK_TENSOR_LIFE - print($"Delete Tensor 0x{handle.ToString("x16")} {AllocationType} Data: 0x{TensorDataPointer.ToString("x16")}"); -#endif if (dtype == TF_DataType.TF_STRING) { long size = 1; foreach (var s in TensorShape.dims) size *= s; var tstr = TensorDataPointer; -#if TRACK_TENSOR_LIFE - print($"Delete TString 0x{handle.ToString("x16")} {AllocationType} Data: 0x{tstr.ToString("x16")}"); -#endif + for (int i = 0; i < size; i++) { c_api.TF_StringDealloc(tstr); diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index b0390f5b..66b5fd3b 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -101,7 +101,7 @@ namespace Tensorflow [MethodImpl(MethodImplOptions.AggressiveInlining)] public static unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len) { - return c_api.TF_NewTensor(dataType, dims, num_dims, data, len, EmptyDeallocator, DeallocatorArgs.Empty); + return TF_NewTensor(dataType, dims, num_dims, data, len, EmptyDeallocator, DeallocatorArgs.Empty); } public static unsafe IntPtr TF_NewTensor(Shape shape, TF_DataType dtype, void* data) diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs index 2d38f9b7..a8870252 100644 --- a/src/TensorFlowNET.Core/Tensors/constant_op.cs +++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs @@ -101,7 +101,7 @@ namespace Tensorflow return op.outputs[0]; } - private static Tensor _eager_reshape(EagerTensor tensor, int[] shape, Context ctx) + private static Tensor _eager_reshape(Tensor tensor, int[] shape, Context ctx) { var attr_t = tensor.dtype.as_datatype_enum(); var dims_t = convert_to_eager_tensor(shape, ctx, dtypes.int32); @@ -111,7 +111,7 @@ namespace Tensorflow return result[0]; } - private static Tensor _eager_fill(int[] dims, EagerTensor value, Context ctx) + private static Tensor _eager_fill(int[] dims, Tensor value, Context ctx) { var attr_t = value.dtype.as_datatype_enum(); var dims_t = convert_to_eager_tensor(dims, ctx, dtypes.int32); @@ -121,7 +121,7 @@ namespace Tensorflow return result[0]; } - private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF_DataType dtype = TF_DataType.DtInvalid) + private static Tensor convert_to_eager_tensor(object value, Context ctx, TF_DataType dtype = TF_DataType.DtInvalid) { ctx.ensure_initialized(); // convert data type @@ -161,7 +161,7 @@ namespace Tensorflow case EagerTensor val: return val; case NDArray val: - return (EagerTensor)val; + return val; case Shape val: return new EagerTensor(val.dims, new Shape(val.ndim)); case TensorShape val: diff --git a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs index 72edda0a..dc588a1a 100644 --- a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs +++ b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs @@ -18,7 +18,7 @@ namespace Tensorflow.Native.UnitTest.Tensors var span = new Span(array, 100, 500); fixed (float* ptr = &MemoryMarshal.GetReference(span)) { - using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, 4 * span.Length)) + using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32)) { Assert.IsFalse(t.IsDisposed); Assert.AreEqual(2000, (int)t.bytesize); @@ -27,7 +27,7 @@ namespace Tensorflow.Native.UnitTest.Tensors fixed (float* ptr = &array[0]) { - using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length)) + using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32)) { Assert.IsFalse(t.IsDisposed); Assert.AreEqual(4000, (int)t.bytesize); diff --git a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs index bd25736c..13c5b141 100644 --- a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs +++ b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs @@ -14,11 +14,6 @@ namespace TensorFlowNET.UnitTest tf.Context.ensure_initialized(); } - [TestCleanup] - public void TestClean() - { - } - public bool Equal(float[] f1, float[] f2) { bool ret = false;