Browse Source

fix constant_value when referenced by ndarray.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
efb1c242e1
7 changed files with 18 additions and 45 deletions
  1. +0
    -2
      src/TensorFlowNET.Core/DisposableObject.cs
  2. +0
    -6
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  3. +5
    -19
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  4. +1
    -7
      src/TensorFlowNET.Core/Numpy/NDArray.cs
  5. +9
    -4
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +3
    -1
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  7. +0
    -6
      test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs

+ 0
- 2
src/TensorFlowNET.Core/DisposableObject.cs View File

@@ -29,14 +29,12 @@ namespace Tensorflow
protected IntPtr _handle; protected IntPtr _handle;
protected bool _disposed; protected bool _disposed;


[SuppressMessage("ReSharper", "UnusedMember.Global")]
protected DisposableObject() protected DisposableObject()
{ } { }


protected DisposableObject(IntPtr handle) protected DisposableObject(IntPtr handle)
=> _handle = handle; => _handle = handle;


[SuppressMessage("ReSharper", "InvertIf")]
private void Dispose(bool disposing) private void Dispose(bool disposing)
{ {
if (_disposed) if (_disposed)


+ 0
- 6
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -94,11 +94,5 @@ namespace Tensorflow.Eager
// c_api.TF_GraphSetOutputHandleShapesAndTypes(target_t.graph, target_t._as_tf_output(), 0, new IntPtr[0], new int[0], new DataType[0], tf.Status.Handle); // c_api.TF_GraphSetOutputHandleShapesAndTypes(target_t.graph, target_t._as_tf_output(), 0, new IntPtr[0], new int[0], new DataType[0], tf.Status.Handle);
} }
} }

protected override void DisposeUnmanagedResources(IntPtr handle)
{
base.DisposeUnmanagedResources(handle);
_eagerTensorHandle.Dispose();
}
} }
} }

+ 5
- 19
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -43,44 +43,30 @@ namespace Tensorflow.NumPy
double val => new Tensor(val), double val => new Tensor(val),
_ => throw new NotImplementedException("") _ => throw new NotImplementedException("")
}; };
_tensor.SetReferencedByNDArray();


var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
_tensor.SetEagerTensorHandle(_handle);
_tensor.SetReferencedByNDArray();
} }


void Init(Array value, Shape? shape = null) void Init(Array value, Shape? shape = null)
{ {
_tensor = new Tensor(value, shape ?? value.GetShape()); _tensor = new Tensor(value, shape ?? value.GetShape());
_tensor.SetReferencedByNDArray(); _tensor.SetReferencedByNDArray();

var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
_tensor.SetEagerTensorHandle(_handle);
} }


void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE) void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
{ {
_tensor = new Tensor(shape, dtype: dtype); _tensor = new Tensor(shape, dtype: dtype);
_tensor.SetReferencedByNDArray(); _tensor.SetReferencedByNDArray();

var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
_tensor.SetEagerTensorHandle(_handle);
} }


void Init(Tensor value, Shape? shape = null) void Init(Tensor value, Shape? shape = null)
{ {
if (shape is not null)
_tensor = new Tensor(value.TensorDataPointer, shape, value.dtype);
else
_tensor = value;

if (_tensor.TensorDataPointer == IntPtr.Zero)
_tensor = tf.get_default_session().eval(_tensor);
// created tensor in graph mode
if (value.TensorDataPointer == IntPtr.Zero)
value = tf.defaultSession.eval(value);


_tensor = new Tensor(value.TensorDataPointer, shape ?? value.shape, value.dtype);
_tensor.SetReferencedByNDArray(); _tensor.SetReferencedByNDArray();

var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
_tensor.SetEagerTensorHandle(_handle);
} }
} }
} }

+ 1
- 7
src/TensorFlowNET.Core/Numpy/NDArray.cs View File

@@ -6,7 +6,7 @@ using static Tensorflow.Binding;


namespace Tensorflow.NumPy namespace Tensorflow.NumPy
{ {
public partial class NDArray : DisposableObject
public partial class NDArray
{ {
Tensor _tensor; Tensor _tensor;
public TF_DataType dtype => _tensor.dtype; public TF_DataType dtype => _tensor.dtype;
@@ -58,11 +58,5 @@ namespace Tensorflow.NumPy
{ {
return tensor_util.to_numpy_string(_tensor); return tensor_util.to_numpy_string(_tensor);
} }

protected override void DisposeUnmanagedResources(IntPtr handle)
{
_tensor.EagerTensorHandle.Dispose();
_tensor.Dispose();
}
} }
} }

+ 9
- 4
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -212,9 +212,12 @@ namespace Tensorflow
return _tf_output.Value; return _tf_output.Value;
} }


public void SetReferencedByNDArray() => isReferencedByNDArray = true;
public void SetEagerTensorHandle(SafeTensorHandleHandle handle) => _eagerTensorHandle = handle;

public void SetReferencedByNDArray()
{
isReferencedByNDArray = true;
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
}
public Tensor MaybeMove() public Tensor MaybeMove()
{ {
var tensor = c_api.TF_TensorMaybeMove(_handle); var tensor = c_api.TF_TensorMaybeMove(_handle);
@@ -256,7 +259,6 @@ namespace Tensorflow
} }
} }


[SuppressMessage("ReSharper", "ConvertIfStatementToSwitchStatement")]
protected override void DisposeUnmanagedResources(IntPtr handle) protected override void DisposeUnmanagedResources(IntPtr handle)
{ {
if (dtype == TF_DataType.TF_STRING) if (dtype == TF_DataType.TF_STRING)
@@ -274,6 +276,9 @@ namespace Tensorflow
} }


c_api.TF_DeleteTensor(handle); c_api.TF_DeleteTensor(handle);

if (_eagerTensorHandle is not null)
_eagerTensorHandle.Dispose();
} }


public bool IsDisposed => _disposed; public bool IsDisposed => _disposed;


+ 3
- 1
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -35,7 +35,9 @@ namespace Tensorflow
/// <returns></returns> /// <returns></returns>
public static NDArray constant_value(Tensor tensor, bool partial = false) public static NDArray constant_value(Tensor tensor, bool partial = false)
{ {
if (tensor is EagerTensor)
if (tensor.IsReferencedByNDArray)
return new NDArray(tensor);
else if (tensor is EagerTensor)
return tensor.numpy(); return tensor.numpy();


NDArray ret = _ConstantValue(tensor, partial); NDArray ret = _ConstantValue(tensor, partial);


+ 0
- 6
test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs View File

@@ -10,11 +10,5 @@ namespace TensorFlowNET.UnitTest
{ {
tf.compat.v1.disable_eager_execution(); tf.compat.v1.disable_eager_execution();
} }

[TestCleanup]
public void TestClean()
{
tf.enable_eager_execution();
}
} }
} }

Loading…
Cancel
Save