Browse Source

Remove _isCreatedInGraphMode in Tensor

tags/TimeSeries
Oceania2018 4 years ago
parent
commit
67a70bf800
6 changed files with 8 additions and 19 deletions
  1. +3
    -1
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  2. +0
    -10
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  3. +0
    -4
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  4. +0
    -1
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  6. +4
    -2
      src/TensorFlowNET.Keras/Engine/Layer.cs

+ 3
- 1
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -2,7 +2,9 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Eager;
using Tensorflow.Graphs; using Tensorflow.Graphs;
using Tensorflow.NumPy;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using static Tensorflow.tensorflow; using static Tensorflow.tensorflow;


@@ -148,7 +150,7 @@ namespace Tensorflow.Functions
src_graph: _func_graph); src_graph: _func_graph);


var captures_from_forward = backwards_graph.external_captures var captures_from_forward = backwards_graph.external_captures
.Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph)
.Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph)
.ToArray(); .ToArray();
foreach(var capture in captures_from_forward) foreach(var capture in captures_from_forward)
{ {


+ 0
- 10
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -32,7 +32,6 @@ namespace Tensorflow


public Tensor() public Tensor()
{ {
_isCreatedInGraphMode = !tf.executing_eagerly();
} }


/// <summary> /// <summary>
@@ -44,8 +43,6 @@ namespace Tensorflow
_handle = handle; _handle = handle;
if (clone && handle != null) if (clone && handle != null)
_handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer()); _handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer());
_isCreatedInGraphMode = !tf.executing_eagerly();
} }


/// <summary> /// <summary>
@@ -59,13 +56,11 @@ namespace Tensorflow
public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype) public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype)
{ {
_handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer()); _handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer());
_isCreatedInGraphMode = !tf.executing_eagerly();
} }


public unsafe Tensor(NDArray nd) public unsafe Tensor(NDArray nd)
{ {
_handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer());
_isCreatedInGraphMode = !tf.executing_eagerly();
} }


#region scala #region scala
@@ -107,13 +102,11 @@ namespace Tensorflow
_value_index = value_index; _value_index = value_index;
_override_dtype = dtype; _override_dtype = dtype;
_id = ops.uid(); _id = ops.uid();
_isCreatedInGraphMode = !tf.executing_eagerly();
} }


protected unsafe void InitTensor(Shape shape, TF_DataType dtype) protected unsafe void InitTensor(Shape shape, TF_DataType dtype)
{ {
_handle = TF_NewTensor(shape, dtype, null); _handle = TF_NewTensor(shape, dtype, null);
_isCreatedInGraphMode = !tf.executing_eagerly();
} }


protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype) protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype)
@@ -122,13 +115,10 @@ namespace Tensorflow
_handle = StringTensor(new byte[][] { bytes }, Shape.Scalar); _handle = StringTensor(new byte[][] { bytes }, Shape.Scalar);
else else
_handle = TF_NewTensor(bytes, shape, dtype); _handle = TF_NewTensor(bytes, shape, dtype);
_isCreatedInGraphMode = !tf.executing_eagerly();
} }


protected unsafe void InitTensor(Array array, Shape? shape = null) protected unsafe void InitTensor(Array array, Shape? shape = null)
{ {
_isCreatedInGraphMode = !tf.executing_eagerly();

shape = shape ?? array.GetShape(); shape = shape ?? array.GetShape();
var dtype = array.GetDataType(); var dtype = array.GetDataType();




+ 0
- 4
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -94,10 +94,6 @@ namespace Tensorflow
/// </summary> /// </summary>
public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle; public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle;


protected bool _isCreatedInGraphMode;
public bool IsCreatedInGraphMode => _isCreatedInGraphMode;

/// <summary> /// <summary>
/// Returns the shape of a tensor. /// Returns the shape of a tensor.
/// </summary> /// </summary>


+ 0
- 1
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -21,7 +21,6 @@ namespace Tensorflow
public Shape shape => items.First().shape; public Shape shape => items.First().shape;
public int rank => items.First().rank; public int rank => items.First().rank;
public Graph graph => items.First().graph; public Graph graph => items.First().graph;
public bool IsCreatedInGraphMode => items.First().IsCreatedInGraphMode;
public bool IsList { get; set; } public bool IsList { get; set; }
public int Length => items.Count(); public int Length => items.Count();




+ 1
- 1
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -68,7 +68,7 @@ namespace Tensorflow
// when this object is garbage collected the deleter will be too. This // when this object is garbage collected the deleter will be too. This
// means ResourceVariables can be part of reference cycles without those // means ResourceVariables can be part of reference cycles without those
// cycles being uncollectable. // cycles being uncollectable.
if (!handle.IsCreatedInGraphMode)
if (handle is EagerTensor)
{ {
_handle = handle.EagerTensorHandle.DangerousGetHandle(); _handle = handle.EagerTensorHandle.DangerousGetHandle();
eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device);


+ 4
- 2
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -18,9 +18,11 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Threading; using System.Threading;
using Tensorflow.Eager;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving; using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils; using Tensorflow.Keras.Utils;
using Tensorflow.NumPy;
using Tensorflow.Train; using Tensorflow.Train;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -118,7 +120,7 @@ namespace Tensorflow.Keras.Engine
bool _in_functional_construction_mode(Tensors inputs) bool _in_functional_construction_mode(Tensors inputs)
{ {
return tf.Context.executing_eagerly() return tf.Context.executing_eagerly()
&& inputs.Count(x => x.IsCreatedInGraphMode) == inputs.Count();
&& inputs.Count(x => x is not EagerTensor && x is not NDArray) == inputs.Count();
} }


public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) public void SetConnectivityMetadata(Tensors inputs, Tensors outputs)
@@ -180,7 +182,7 @@ namespace Tensorflow.Keras.Engine
tf.init_scope(); tf.init_scope();


bool need_restore_mode = false; bool need_restore_mode = false;
if (!inputs.IsCreatedInGraphMode || tf.Context.is_build_function())
if (inputs.Any(x => x is EagerTensor) || tf.Context.is_build_function())
{ {
need_restore_mode = true; need_restore_mode = true;
tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); tf.Context.eager_mode(isFunc: tf.Context.is_build_function());


Loading…
Cancel
Save