| @@ -2,7 +2,9 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Eager; | |||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using Tensorflow.NumPy; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.tensorflow; | using static Tensorflow.tensorflow; | ||||
| @@ -148,7 +150,7 @@ namespace Tensorflow.Functions | |||||
| src_graph: _func_graph); | src_graph: _func_graph); | ||||
| var captures_from_forward = backwards_graph.external_captures | var captures_from_forward = backwards_graph.external_captures | ||||
| .Where(x => x.IsCreatedInGraphMode && x.graph == _func_graph) | |||||
| .Where(x => x is not EagerTensor && x is not NDArray && x.graph == _func_graph) | |||||
| .ToArray(); | .ToArray(); | ||||
| foreach(var capture in captures_from_forward) | foreach(var capture in captures_from_forward) | ||||
| { | { | ||||
| @@ -32,7 +32,6 @@ namespace Tensorflow | |||||
| public Tensor() | public Tensor() | ||||
| { | { | ||||
| _isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -44,8 +43,6 @@ namespace Tensorflow | |||||
| _handle = handle; | _handle = handle; | ||||
| if (clone && handle != null) | if (clone && handle != null) | ||||
| _handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer()); | _handle = TF_NewTensor(shape, dtype, data: TensorDataPointer.ToPointer()); | ||||
| _isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -59,13 +56,11 @@ namespace Tensorflow | |||||
| public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype) | public unsafe Tensor(IntPtr data_ptr, Shape shape, TF_DataType dtype) | ||||
| { | { | ||||
| _handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer()); | _handle = TF_NewTensor(shape, dtype, data: data_ptr.ToPointer()); | ||||
| _isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| public unsafe Tensor(NDArray nd) | public unsafe Tensor(NDArray nd) | ||||
| { | { | ||||
| _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); | _handle = TF_NewTensor(nd.shape, nd.dtype, nd.data.ToPointer()); | ||||
| _isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| #region scala | #region scala | ||||
| @@ -107,13 +102,11 @@ namespace Tensorflow | |||||
| _value_index = value_index; | _value_index = value_index; | ||||
| _override_dtype = dtype; | _override_dtype = dtype; | ||||
| _id = ops.uid(); | _id = ops.uid(); | ||||
| _isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | protected unsafe void InitTensor(Shape shape, TF_DataType dtype) | ||||
| { | { | ||||
| _handle = TF_NewTensor(shape, dtype, null); | _handle = TF_NewTensor(shape, dtype, null); | ||||
| _isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype) | protected unsafe void InitTensor(Shape shape, byte[] bytes, TF_DataType dtype) | ||||
| @@ -122,13 +115,10 @@ namespace Tensorflow | |||||
| _handle = StringTensor(new byte[][] { bytes }, Shape.Scalar); | _handle = StringTensor(new byte[][] { bytes }, Shape.Scalar); | ||||
| else | else | ||||
| _handle = TF_NewTensor(bytes, shape, dtype); | _handle = TF_NewTensor(bytes, shape, dtype); | ||||
| _isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| } | } | ||||
| protected unsafe void InitTensor(Array array, Shape? shape = null) | protected unsafe void InitTensor(Array array, Shape? shape = null) | ||||
| { | { | ||||
| _isCreatedInGraphMode = !tf.executing_eagerly(); | |||||
| shape = shape ?? array.GetShape(); | shape = shape ?? array.GetShape(); | ||||
| var dtype = array.GetDataType(); | var dtype = array.GetDataType(); | ||||
| @@ -94,10 +94,6 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle; | public SafeEagerTensorHandle EagerTensorHandle => _eagerTensorHandle; | ||||
| protected bool _isCreatedInGraphMode; | |||||
| public bool IsCreatedInGraphMode => _isCreatedInGraphMode; | |||||
| /// <summary> | /// <summary> | ||||
| /// Returns the shape of a tensor. | /// Returns the shape of a tensor. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -21,7 +21,6 @@ namespace Tensorflow | |||||
| public Shape shape => items.First().shape; | public Shape shape => items.First().shape; | ||||
| public int rank => items.First().rank; | public int rank => items.First().rank; | ||||
| public Graph graph => items.First().graph; | public Graph graph => items.First().graph; | ||||
| public bool IsCreatedInGraphMode => items.First().IsCreatedInGraphMode; | |||||
| public bool IsList { get; set; } | public bool IsList { get; set; } | ||||
| public int Length => items.Count(); | public int Length => items.Count(); | ||||
| @@ -68,7 +68,7 @@ namespace Tensorflow | |||||
| // when this object is garbage collected the deleter will be too. This | // when this object is garbage collected the deleter will be too. This | ||||
| // means ResourceVariables can be part of reference cycles without those | // means ResourceVariables can be part of reference cycles without those | ||||
| // cycles being uncollectable. | // cycles being uncollectable. | ||||
| if (!handle.IsCreatedInGraphMode) | |||||
| if (handle is EagerTensor) | |||||
| { | { | ||||
| _handle = handle.EagerTensorHandle.DangerousGetHandle(); | _handle = handle.EagerTensorHandle.DangerousGetHandle(); | ||||
| eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); | eager_resource_deleter = new EagerResourceDeleter(handle, handle.Device); | ||||
| @@ -18,9 +18,11 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Threading; | using System.Threading; | ||||
| using Tensorflow.Eager; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using Tensorflow.NumPy; | |||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -118,7 +120,7 @@ namespace Tensorflow.Keras.Engine | |||||
| bool _in_functional_construction_mode(Tensors inputs) | bool _in_functional_construction_mode(Tensors inputs) | ||||
| { | { | ||||
| return tf.Context.executing_eagerly() | return tf.Context.executing_eagerly() | ||||
| && inputs.Count(x => x.IsCreatedInGraphMode) == inputs.Count(); | |||||
| && inputs.Count(x => x is not EagerTensor && x is not NDArray) == inputs.Count(); | |||||
| } | } | ||||
| public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) | public void SetConnectivityMetadata(Tensors inputs, Tensors outputs) | ||||
| @@ -180,7 +182,7 @@ namespace Tensorflow.Keras.Engine | |||||
| tf.init_scope(); | tf.init_scope(); | ||||
| bool need_restore_mode = false; | bool need_restore_mode = false; | ||||
| if (!inputs.IsCreatedInGraphMode || tf.Context.is_build_function()) | |||||
| if (inputs.Any(x => x is EagerTensor) || tf.Context.is_build_function()) | |||||
| { | { | ||||
| need_restore_mode = true; | need_restore_mode = true; | ||||
| tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); | ||||