diff --git a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
index a512fba9..3ab800f6 100644
--- a/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
@@ -10,7 +10,7 @@ namespace Tensorflow.Eager
public EagerTensor(SafeTensorHandleHandle handle)
{
_id = ops.uid();
- EagerTensorHandle = handle;
+ _eagerTensorHandle = handle;
Resolve();
}
@@ -59,20 +59,14 @@ namespace Tensorflow.Eager
void NewEagerTensorHandle(IntPtr h)
{
_id = ops.uid();
- EagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle);
+ _eagerTensorHandle = c_api.TFE_NewTensorHandle(h, tf.Status.Handle);
tf.Status.Check(true);
-#if TRACK_TENSOR_LIFE
- print($"New EagerTensorHandle {EagerTensorHandle} {Id} From 0x{h.ToString("x16")}");
-#endif
}
private void Resolve()
{
- _handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.Status.Handle);
+ _handle = c_api.TFE_TensorHandleResolve(_eagerTensorHandle, tf.Status.Handle);
tf.Status.Check(true);
-#if TRACK_TENSOR_LIFE
- print($"Take EagerTensorHandle {EagerTensorHandle} {Id} Resolving 0x{_handle.ToString("x16")}");
-#endif
}
///
@@ -104,7 +98,7 @@ namespace Tensorflow.Eager
protected override void DisposeUnmanagedResources(IntPtr handle)
{
base.DisposeUnmanagedResources(handle);
- EagerTensorHandle.Dispose();
+ _eagerTensorHandle.Dispose();
}
}
}
diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
index 12213857..303d01d8 100644
--- a/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
+++ b/src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
@@ -25,6 +25,7 @@ namespace Tensorflow.NumPy
bool val => new NDArray(val),
byte val => new NDArray(val),
int val => new NDArray(val),
+ long val => new NDArray(val),
float val => new NDArray(val),
double val => new NDArray(val),
_ => throw new NotImplementedException("")
@@ -32,26 +33,44 @@ namespace Tensorflow.NumPy
void Init(T value) where T : unmanaged
{
- _tensor = new EagerTensor(value);
+ _tensor = value switch
+ {
+ bool val => new Tensor(val),
+ byte val => new Tensor(val),
+ int val => new Tensor(val),
+ long val => new Tensor(val),
+ float val => new Tensor(val),
+ double val => new Tensor(val),
+ _ => throw new NotImplementedException("")
+ };
_tensor.SetReferencedByNDArray();
+
+ var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
+ _tensor.SetEagerTensorHandle(_handle);
}
void Init(Array value, Shape? shape = null)
{
- _tensor = new EagerTensor(value, shape ?? value.GetShape());
+ _tensor = new Tensor(value, shape ?? value.GetShape());
_tensor.SetReferencedByNDArray();
+
+ var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
+ _tensor.SetEagerTensorHandle(_handle);
}
void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
{
- _tensor = new EagerTensor(shape, dtype: dtype);
+ _tensor = new Tensor(shape, dtype: dtype);
_tensor.SetReferencedByNDArray();
+
+ var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
+ _tensor.SetEagerTensorHandle(_handle);
}
void Init(Tensor value, Shape? shape = null)
{
if (shape is not null)
- _tensor = tf.reshape(value, shape);
+ _tensor = new Tensor(value.TensorDataPointer, shape, value.dtype);
else
_tensor = value;
@@ -59,6 +78,9 @@ namespace Tensorflow.NumPy
_tensor = tf.get_default_session().eval(_tensor);
_tensor.SetReferencedByNDArray();
+
+ var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
+ _tensor.SetEagerTensorHandle(_handle);
}
}
}
diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs
index 1cfc9b4e..f06b8366 100644
--- a/src/TensorFlowNET.Core/Numpy/NDArray.cs
+++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs
@@ -6,7 +6,7 @@ using static Tensorflow.Binding;
namespace Tensorflow.NumPy
{
- public partial class NDArray
+ public partial class NDArray : DisposableObject
{
Tensor _tensor;
public TF_DataType dtype => _tensor.dtype;
@@ -58,5 +58,11 @@ namespace Tensorflow.NumPy
{
return tensor_util.to_numpy_string(_tensor);
}
+
+ protected override void DisposeUnmanagedResources(IntPtr handle)
+ {
+ _tensor.EagerTensorHandle.Dispose();
+ _tensor.Dispose();
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
index 49847856..1f839ee7 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
@@ -53,10 +53,9 @@ namespace Tensorflow
/// Pointer to unmanaged, fixed or pinned memory which the caller owns
/// Tensor shape
/// TF data type
- /// Size of the tensor in memory
- public Tensor(IntPtr data_ptr, long[] shape, TF_DataType dType, int num_bytes)
+ public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype)
{
- _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: data_ptr, len: (ulong)num_bytes);
+ _handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer());
isCreatedInGraphMode = !tf.executing_eagerly();
}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index f0dd4274..bf52d13c 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -89,10 +89,11 @@ namespace Tensorflow
///
public object Tag { get; set; }
+ protected SafeTensorHandleHandle _eagerTensorHandle;
///
/// TFE_TensorHandle
///
- public SafeTensorHandleHandle EagerTensorHandle { get; set; }
+ public SafeTensorHandleHandle EagerTensorHandle => _eagerTensorHandle;
protected bool isReferencedByNDArray;
public bool IsReferencedByNDArray => isReferencedByNDArray;
@@ -212,6 +213,7 @@ namespace Tensorflow
}
public void SetReferencedByNDArray() => isReferencedByNDArray = true;
+ public void SetEagerTensorHandle(SafeTensorHandleHandle handle) => _eagerTensorHandle = handle;
public Tensor MaybeMove()
{
@@ -254,30 +256,16 @@ namespace Tensorflow
}
}
- ///
- /// Dispose any managed resources.
- ///
- /// Equivalent to what you would perform inside
- protected override void DisposeManagedResources()
- {
-
- }
-
[SuppressMessage("ReSharper", "ConvertIfStatementToSwitchStatement")]
protected override void DisposeUnmanagedResources(IntPtr handle)
{
-#if TRACK_TENSOR_LIFE
- print($"Delete Tensor 0x{handle.ToString("x16")} {AllocationType} Data: 0x{TensorDataPointer.ToString("x16")}");
-#endif
if (dtype == TF_DataType.TF_STRING)
{
long size = 1;
foreach (var s in TensorShape.dims)
size *= s;
var tstr = TensorDataPointer;
-#if TRACK_TENSOR_LIFE
- print($"Delete TString 0x{handle.ToString("x16")} {AllocationType} Data: 0x{tstr.ToString("x16")}");
-#endif
+
for (int i = 0; i < size; i++)
{
c_api.TF_StringDealloc(tstr);
diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
index b0390f5b..66b5fd3b 100644
--- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs
@@ -101,7 +101,7 @@ namespace Tensorflow
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int num_dims, IntPtr data, ulong len)
{
- return c_api.TF_NewTensor(dataType, dims, num_dims, data, len, EmptyDeallocator, DeallocatorArgs.Empty);
+ return TF_NewTensor(dataType, dims, num_dims, data, len, EmptyDeallocator, DeallocatorArgs.Empty);
}
public static unsafe IntPtr TF_NewTensor(Shape shape, TF_DataType dtype, void* data)
diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs
index 2d38f9b7..a8870252 100644
--- a/src/TensorFlowNET.Core/Tensors/constant_op.cs
+++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs
@@ -101,7 +101,7 @@ namespace Tensorflow
return op.outputs[0];
}
- private static Tensor _eager_reshape(EagerTensor tensor, int[] shape, Context ctx)
+ private static Tensor _eager_reshape(Tensor tensor, int[] shape, Context ctx)
{
var attr_t = tensor.dtype.as_datatype_enum();
var dims_t = convert_to_eager_tensor(shape, ctx, dtypes.int32);
@@ -111,7 +111,7 @@ namespace Tensorflow
return result[0];
}
- private static Tensor _eager_fill(int[] dims, EagerTensor value, Context ctx)
+ private static Tensor _eager_fill(int[] dims, Tensor value, Context ctx)
{
var attr_t = value.dtype.as_datatype_enum();
var dims_t = convert_to_eager_tensor(dims, ctx, dtypes.int32);
@@ -121,7 +121,7 @@ namespace Tensorflow
return result[0];
}
- private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF_DataType dtype = TF_DataType.DtInvalid)
+ private static Tensor convert_to_eager_tensor(object value, Context ctx, TF_DataType dtype = TF_DataType.DtInvalid)
{
ctx.ensure_initialized();
// convert data type
@@ -161,7 +161,7 @@ namespace Tensorflow
case EagerTensor val:
return val;
case NDArray val:
- return (EagerTensor)val;
+ return val;
case Shape val:
return new EagerTensor(val.dims, new Shape(val.ndim));
case TensorShape val:
diff --git a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs
index 72edda0a..dc588a1a 100644
--- a/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs
+++ b/test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs
@@ -18,7 +18,7 @@ namespace Tensorflow.Native.UnitTest.Tensors
var span = new Span(array, 100, 500);
fixed (float* ptr = &MemoryMarshal.GetReference(span))
{
- using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32, 4 * span.Length))
+ using (var t = new Tensor((IntPtr)ptr, new long[] { span.Length }, tf.float32))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(2000, (int)t.bytesize);
@@ -27,7 +27,7 @@ namespace Tensorflow.Native.UnitTest.Tensors
fixed (float* ptr = &array[0])
{
- using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32, 4 * array.Length))
+ using (var t = new Tensor((IntPtr)ptr, new long[] { array.Length }, tf.float32))
{
Assert.IsFalse(t.IsDisposed);
Assert.AreEqual(4000, (int)t.bytesize);
diff --git a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
index bd25736c..13c5b141 100644
--- a/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
+++ b/test/TensorFlowNET.UnitTest/EagerModeTestBase.cs
@@ -14,11 +14,6 @@ namespace TensorFlowNET.UnitTest
tf.Context.ensure_initialized();
}
- [TestCleanup]
- public void TestClean()
- {
- }
-
public bool Equal(float[] f1, float[] f2)
{
bool ret = false;