Browse Source

Remove _isCreatedInGraphMode in Tensor

tags/TimeSeries
Oceania2018 4 years ago
parent
commit
67a70bf800
6 changed files with 8 additions and 19 deletions
  1. +3
    -1
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  2. +0
    -10
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  3. +0
    -4
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  4. +0
    -1
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  6. +4
    -2
      src/TensorFlowNET.Keras/Engine/Layer.cs

+ 3
- 1
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -2,7 +2,9 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Eager;
using Tensorflow.Graphs;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.tensorflow;

@@ -148,7 +150,7 @@ namespace Tensorflow.Functions
src_graph: _func_graph);

var captures_from_forward = backwards_graph.external_captures
.Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph)
.Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph)
.ToArray();
foreach(var capture in captures_from_forward)
{


+ 0
- 10
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -32,7 +32,6 @@ namespace Tensorflow

public Tensor()
{
_isCreatedInGraphMode = !tf.executing_eagerly();
}

/// <summary>
@@ -44,8 +43,6 @@ namespace Tensorflow
_handle = handle;
if (clone && handle != null)
_handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer());
_isCreatedInGraphMode = !tf.executing_eagerly();
}

/// <summary>
@@ -59,13 +56,11 @@ namespace Tensorflow
public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype)
{
_handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer());
_isCreatedInGraphMode = !tf.executing_eagerly();
}

public unsafe Tensor(NDArray nd)
{
_handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer());
_isCreatedInGraphMode = !tf.executing_eagerly();
}

#region scala
@@ -107,13 +102,11 @@ namespace Tensorflow
_value_index = value_index;
_override_dtype = dtype;
_id = ops.uid();
_isCreatedInGraphMode = !tf.executing_eagerly();
}

protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
{
_handle = TF_NewTensor(shape, dtype, null);
_isCreatedInGraphMode = !tf.executing_eagerly();
}

protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype)
@@ -122,13 +115,10 @@ namespace Tensorflow
_handle = StringTensor(new byte[][] { bytes }, Shape.Scalar);
else
_handle = TF_NewTensor(bytes, shape, dtype);
_isCreatedInGraphMode = !tf.executing_eagerly();
}

protected unsafe void InitTensor(Array array, Shape? shape = null)
{
_isCreatedInGraphMode = !tf.executing_eagerly();

shape = shape ?? array.GetShape();
var dtype = array.GetDataType();



+ 0
- 4
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -94,10 +94,6 @@ namespace Tensorflow
/// </summary>
public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle;

protected bool _isCreatedInGraphMode;
public bool IsCreatedInGraphMode => _isCreatedInGraphMode;

/// <summary>
/// Returns the shape of a tensor.
/// </summary>


+ 0
- 1
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -21,7 +21,6 @@ namespace Tensorflow
public Shape shape => items.First().shape;
public int rank => items.First().rank;
public Graph graph => items.First().graph;
public bool IsCreatedInGraphMode => items.First().IsCreatedInGraphMode;
public bool IsList { get; set; }
public int Length => items.Count();



+ 1
- 1
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -68,7 +68,7 @@ namespace Tensorflow
// when this object is garbage collected the deleter will be too. This
// means ResourceVariables can be part of reference cycles without those
// cycles being uncollectable.
if (!handle.IsCreatedInGraphMode)
if (handle is EagerTensor)
{
_handle = handle.EagerTensorHandle.DangerousGetHandle();
eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device);


+ 4
- 2
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -18,9 +18,11 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using Tensorflow.Eager;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using Tensorflow.NumPy;
using Tensorflow.Train;
using static Tensorflow.Binding;

@@ -118,7 +120,7 @@ namespace Tensorflow.Keras.Engine
bool _in_functional_construction_mode(Tensors inputs)
{
return tf.Context.executing_eagerly()
&& inputs.Count(x => x.IsCreatedInGraphMode) == inputs.Count();
&& inputs.Count(x => x is not EagerTensor && x is not NDArray) == inputs.Count();
}

public void SetConnectivityMetadata(Tensors inputs, Tensors outputs)
@@ -180,7 +182,7 @@ namespace Tensorflow.Keras.Engine
tf.init_scope();

bool need_restore_mode = false;
if (!inputs.IsCreatedInGraphMode || tf.Context.is_build_function())
if (inputs.Any(x => x is EagerTensor) || tf.Context.is_build_function())
{
need_restore_mode = true;
tf.Context.eager_mode(isFunc: tf.Context.is_build_function());


Loading…
Cancel
Save