Browse Source

fix constant_value when referenced by ndarray.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
efb1c242e1
7 changed files with 18 additions and 45 deletions
  1. +0
    -2
      src/TensorFlowNET.Core/DisposableObject.cs
  2. +0
    -6
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  3. +5
    -19
      src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs
  4. +1
    -7
      src/TensorFlowNET.Core/Numpy/NDArray.cs
  5. +9
    -4
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  6. +3
    -1
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  7. +0
    -6
      test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs

+ 0
- 2
src/TensorFlowNET.Core/DisposableObject.cs View File

@@ -29,14 +29,12 @@ namespace Tensorflow
protected IntPtr _handle;
protected bool _disposed;

[SuppressMessage("ReSharper", "UnusedMember.Global")]
protected DisposableObject()
{ }

protected DisposableObject(IntPtr handle)
=> _handle = handle;

[SuppressMessage("ReSharper", "InvertIf")]
private void Dispose(bool disposing)
{
if (_disposed)


+ 0
- 6
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -94,11 +94,5 @@ namespace Tensorflow.Eager
// c_api.TF_GraphSetOutputHandleShapesAndTypes(target_t.graph, target_t._as_tf_output(), 0, new IntPtr[0], new int[0], new DataType[0], tf.Status.Handle);
}
}

protected override void DisposeUnmanagedResources(IntPtr handle)
{
base.DisposeUnmanagedResources(handle);
_eagerTensorHandle.Dispose();
}
}
}

+ 5
- 19
src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs View File

@@ -43,44 +43,30 @@ namespace Tensorflow.NumPy
double val => new Tensor(val),
_ => throw new NotImplementedException("")
};
_tensor.SetReferencedByNDArray();

var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
_tensor.SetEagerTensorHandle(_handle);
_tensor.SetReferencedByNDArray();
}

void Init(Array value, Shape? shape = null)
{
_tensor = new Tensor(value, shape ?? value.GetShape());
_tensor.SetReferencedByNDArray();

var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
_tensor.SetEagerTensorHandle(_handle);
}

void Init(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
{
_tensor = new Tensor(shape, dtype: dtype);
_tensor.SetReferencedByNDArray();

var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
_tensor.SetEagerTensorHandle(_handle);
}

void Init(Tensor value, Shape? shape = null)
{
if (shape is not null)
_tensor = new Tensor(value.TensorDataPointer, shape, value.dtype);
else
_tensor = value;

if (_tensor.TensorDataPointer == IntPtr.Zero)
_tensor = tf.get_default_session().eval(_tensor);
// created tensor in graph mode
if (value.TensorDataPointer == IntPtr.Zero)
value = tf.defaultSession.eval(value);

_tensor = new Tensor(value.TensorDataPointer, shape ?? value.shape, value.dtype);
_tensor.SetReferencedByNDArray();

var _handle = c_api.TFE_NewTensorHandle(_tensor, tf.Status.Handle);
_tensor.SetEagerTensorHandle(_handle);
}
}
}

+ 1
- 7
src/TensorFlowNET.Core/Numpy/NDArray.cs View File

@@ -6,7 +6,7 @@ using static Tensorflow.Binding;

namespace Tensorflow.NumPy
{
public partial class NDArray : DisposableObject
public partial class NDArray
{
Tensor _tensor;
public TF_DataType dtype => _tensor.dtype;
@@ -58,11 +58,5 @@ namespace Tensorflow.NumPy
{
return tensor_util.to_numpy_string(_tensor);
}

protected override void DisposeUnmanagedResources(IntPtr handle)
{
_tensor.EagerTensorHandle.Dispose();
_tensor.Dispose();
}
}
}

+ 9
- 4
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -212,9 +212,12 @@ namespace Tensorflow
return _tf_output.Value;
}

public void SetReferencedByNDArray() => isReferencedByNDArray = true;
public void SetEagerTensorHandle(SafeTensorHandleHandle handle) => _eagerTensorHandle = handle;

public void SetReferencedByNDArray()
{
isReferencedByNDArray = true;
_eagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
}
public Tensor MaybeMove()
{
var tensor = c_api.TF_TensorMaybeMove(_handle);
@@ -256,7 +259,6 @@ namespace Tensorflow
}
}

[SuppressMessage("ReSharper", "ConvertIfStatementToSwitchStatement")]
protected override void DisposeUnmanagedResources(IntPtr handle)
{
if (dtype == TF_DataType.TF_STRING)
@@ -274,6 +276,9 @@ namespace Tensorflow
}

c_api.TF_DeleteTensor(handle);

if (_eagerTensorHandle is not null)
_eagerTensorHandle.Dispose();
}

public bool IsDisposed => _disposed;


+ 3
- 1
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -35,7 +35,9 @@ namespace Tensorflow
/// <returns></returns>
public static NDArray constant_value(Tensor tensor, bool partial = false)
{
if (tensor is EagerTensor)
if (tensor.IsReferencedByNDArray)
return new NDArray(tensor);
else if (tensor is EagerTensor)
return tensor.numpy();

NDArray ret = _ConstantValue(tensor, partial);


+ 0
- 6
test/TensorFlowNET.Graph.UnitTest/GraphModeTestBase.cs View File

@@ -10,11 +10,5 @@ namespace TensorFlowNET.UnitTest
{
tf.compat.v1.disable_eager_execution();
}

[TestCleanup]
public void TestClean()
{
tf.enable_eager_execution();
}
}
}

Loading…
Cancel
Save