From 67a70bf800da6172e8e5ae51dc868557caf7a387 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 28 Nov 2021 16:51:10 -0600 Subject: [PATCH] Remove _isCreatedInGraphMode in Tensor --- .../Functions/TapeGradientFunctions.cs | 4 +++- src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs | 10 ---------- src/TensorFlowNET.Core/Tensors/Tensor.cs | 4 ---- src/TensorFlowNET.Core/Tensors/Tensors.cs | 1 - .../Variables/BaseResourceVariable.cs | 2 +- src/TensorFlowNET.Keras/Engine/Layer.cs | 6 ++++-- 6 files changed, 8 insertions(+), 19 deletions(-) diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index b4241304..9f216ff7 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -2,7 +2,9 @@ using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Eager; using Tensorflow.Graphs; +using Tensorflow.NumPy; using static Tensorflow.Binding; using static Tensorflow.tensorflow; @@ -148,7 +150,7 @@ namespace Tensorflow.Functions src_graph: _func_graph); var captures_from_forward = backwards_graph.external_captures - .Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph) + .Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph) .ToArray(); foreach(var capture in captures_from_forward) { diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 0e460bd3..79b8d2c5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -32,7 +32,6 @@ namespace Tensorflow public Tensor() { - _isCreatedInGraphMode = !tf.executing_eagerly(); } /// @@ -44,8 +43,6 @@ namespace Tensorflow _handle = handle; if (clone && handle != null) _handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer()); - - _isCreatedInGraphMode = !tf.executing_eagerly(); } /// @@ -59,13 +56,11 @@ namespace Tensorflow public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype) { _handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer()); - _isCreatedInGraphMode = !tf.executing_eagerly(); } public unsafe Tensor(NDArray nd) { _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); - _isCreatedInGraphMode = !tf.executing_eagerly(); } #region scala @@ -107,13 +102,11 @@ namespace Tensorflow _value_index = value_index; _override_dtype = dtype; _id = ops.uid(); - _isCreatedInGraphMode = !tf.executing_eagerly(); } protected unsafe void InitTensor(Shape shape, TF_DataType dtype) { _handle = TF_NewTensor(shape, dtype, null); - _isCreatedInGraphMode = !tf.executing_eagerly(); } protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype) @@ -122,13 +115,10 @@ namespace Tensorflow _handle = StringTensor(new byte[][] { bytes }, Shape.Scalar); else _handle = TF_NewTensor(bytes, shape, dtype); - _isCreatedInGraphMode = !tf.executing_eagerly(); } protected unsafe void InitTensor(Array array, Shape? shape = null) { - _isCreatedInGraphMode = !tf.executing_eagerly(); - shape = shape ?? array.GetShape(); var dtype = array.GetDataType(); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 3f4ef8e5..19f91961 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -94,10 +94,6 @@ namespace Tensorflow /// public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle; - protected bool _isCreatedInGraphMode; - - public bool IsCreatedInGraphMode => _isCreatedInGraphMode; - /// /// Returns the shape of a tensor. /// diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 88f8fe84..ecd844d1 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -21,7 +21,6 @@ namespace Tensorflow public Shape shape => items.First().shape; public int rank => items.First().rank; public Graph graph => items.First().graph; - public bool IsCreatedInGraphMode => items.First().IsCreatedInGraphMode; public bool IsList { get; set; } public int Length => items.Count(); diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 0872d69c..b270ec57 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -68,7 +68,7 @@ namespace Tensorflow // when this object is garbage collected the deleter will be too. This // means ResourceVariables can be part of reference cycles without those // cycles being uncollectable. - if (!handle.IsCreatedInGraphMode) + if (handle is EagerTensor) { _handle = handle.EagerTensorHandle.DangerousGetHandle(); eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 33894136..e9d58b6f 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -18,9 +18,11 @@ using System; using System.Collections.Generic; using System.Linq; using System.Threading; +using Tensorflow.Eager; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Saving; using Tensorflow.Keras.Utils; +using Tensorflow.NumPy; using Tensorflow.Train; using static Tensorflow.Binding; @@ -118,7 +120,7 @@ namespace Tensorflow.Keras.Engine bool _in_functional_construction_mode(Tensors inputs) { return tf.Context.executing_eagerly() - && inputs.Count(x => x.IsCreatedInGraphMode) == inputs.Count(); + && inputs.Count(x => x is not EagerTensor && x is not NDArray) == inputs.Count(); } public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) @@ -180,7 +182,7 @@ namespace Tensorflow.Keras.Engine tf.init_scope(); bool need_restore_mode = false; - if (!inputs.IsCreatedInGraphMode || tf.Context.is_build_function()) + if (inputs.Any(x => x is EagerTensor) || tf.Context.is_build_function()) { need_restore_mode = true; tf.Context.eager_mode(isFunc: tf.Context.is_build_function());