diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 855364fa..17e85306 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -105,6 +105,11 @@ namespace Tensorflow.Contexts context_switches.Pop(); } + public void reset_context() + { + c_api.TFE_ContextClearCaches(_handle); + } + public void Dispose() => _handle.Dispose(); } diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index bc71eee4..e9203878 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -34,6 +34,7 @@ namespace Tensorflow.Functions public ConcreteFunction(string name) { func_graph = new FuncGraph(name); + func_graph.as_default(); } public ConcreteFunction(FuncGraph graph, Dictionary attrs) @@ -48,17 +49,16 @@ namespace Tensorflow.Functions string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; // IntPtr func_handle; - using (var graph = new FuncGraph(func_name)) - { - var input = tf.placeholder(dtype); - var output = func(input); - - var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); - _handle = graph.ToGraph(opers, - new[] { input }, - new[] { output }, - null); - } + using var graph = new FuncGraph(func_name); + graph.as_default(); + var input = tf.placeholder(dtype); + var output = func(input); + + var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); + _handle = graph.ToGraph(opers, + new[] { input }, + new[] { output }, + null); } public ConcreteFunction(Func func, TF_DataType dtype) @@ -66,19 +66,19 @@ namespace Tensorflow.Functions string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; // IntPtr func_handle; - using (var graph = new FuncGraph(func_name)) - { - var input = tf.placeholder(dtype); - var output = func(input); + using var graph = new FuncGraph(func_name); + graph.as_default(); - OutputStructure = output.structure; + var input = tf.placeholder(dtype); + var output = func(input); - var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); - _handle = graph.ToGraph(opers, - new[] { input }, - new[] { output.variant_tensor }, - null); - } + OutputStructure = output.structure; + + var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); + _handle = graph.ToGraph(opers, + new[] { input }, + new[] { output.variant_tensor }, + null); } public ConcreteFunction(Func func, @@ -87,22 +87,22 @@ namespace Tensorflow.Functions string func_name = $"autograph_{Guid.NewGuid()}_{func.Method.Name}"; // IntPtr func_handle; - using (var graph = new FuncGraph(func_name)) - { - var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args"); - var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args"); - var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args"); - var outputs = func(input1, (input2, input3)); - - Outputs = new[] { outputs.Item1, outputs.Item2 }; - OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() }; - - var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); - _handle = graph.ToGraph(opers, - new[] { input1, input2, input3 }, - new[] { outputs.Item1, outputs.Item2 }, - null); - } + using var graph = new FuncGraph(func_name); + graph.as_default(); + + var input1 = tf.placeholder(dtypes[0], shape: shapes[0], name: "args"); + var input2 = tf.placeholder(dtypes[1], shape: shapes[1], name: "args"); + var input3 = tf.placeholder(dtypes[2], shape: shapes[2], name: "args"); + var outputs = func(input1, (input2, input3)); + + Outputs = new[] { outputs.Item1, outputs.Item2 }; + OutputStructure = new[] { outputs.Item1.ToTensorSpec(), outputs.Item2.ToTensorSpec() }; + + var opers = graph._nodes_by_name.Values.Select(x => x as Operation).ToArray(); + _handle = graph.ToGraph(opers, + new[] { input1, input2, input3 }, + new[] { outputs.Item1, outputs.Item2 }, + null); } public void ToGraph(Tensors inputs, Tensors outputs) diff --git a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs index f615f6a4..e0896253 100644 --- a/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs +++ b/src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs @@ -26,6 +26,7 @@ namespace Tensorflow.Functions var output_names = new string[0]; _func_graph = new FuncGraph(graph, name, attrs); + _func_graph.as_default(); _func_graph.ToGraph(operations, inputs, outputs, output_names); } diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index 45b0de26..4559fc5d 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -85,6 +85,7 @@ namespace Tensorflow.Functions var gradients_wrt_outputs = new List(); var backwards_graph = new FuncGraph($"{_BACKWARD_PREFIX}_{ops.uid()}"); + backwards_graph.as_default(); foreach (var output in trainable_outputs) gradients_wrt_outputs.Add(tf.placeholder(output.dtype, output.shape)); var gradients_wrt_inputs = gradients_util._GradientsHelper(trainable_outputs.ToArray(), diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs index bbec00ea..901cbd6f 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraph.cs @@ -13,6 +13,7 @@ namespace Tensorflow.Graphs // IntPtr func_handle; using (var graph = new FuncGraph(func_name)) { + graph.as_default(); var input = tf.placeholder(tf.int32); var output = func(input); @@ -43,6 +44,7 @@ namespace Tensorflow.Graphs // IntPtr func_handle; using (var graph = new FuncGraph(func_name)) { + graph.as_default(); var input1 = tf.placeholder(tf.int32); var input2 = tf.placeholder(tf.int32); var output = func(input1, input2); diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index b80c659f..0d0ac581 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -30,9 +30,6 @@ namespace Tensorflow.Graphs public Tensor[] internal_captures() => _captures.Select(x => x.Value.Item2).ToArray(); - // new Dictionary _captures = new Dictionary(); - // public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray(); - /// /// Construct a new FuncGraph. /// @@ -43,8 +40,6 @@ namespace Tensorflow.Graphs outer_graph = outer_graph.OuterGraph; _graph_key = name; building_function = true; - tf.Context.graph_mode(); - as_default(); } public FuncGraph(IntPtr handle, string name, Dictionary attrs) : base() @@ -58,9 +53,6 @@ namespace Tensorflow.Graphs // Will to test if FuncGraph has memory leak // c_api.TF_DeleteGraph(_handle); _handle = handle; - - tf.Context.graph_mode(); - as_default(); } public IntPtr ToGraph(Operation[] opers, @@ -110,11 +102,21 @@ namespace Tensorflow.Graphs return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); } - public Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid) + const int _EAGER_CONST_THRESHOLD = 128; + public Tensor capture(Tensor tensor, string name = null, TensorShape shape = null) { if(tensor is EagerTensor) { - throw new NotImplementedException(""); + if (name == null) + name = ops.uid().ToString(); + + // Small EagerTensors are captured with Const ops + if (dtypes.is_value_dtype(tensor.dtype) + && (tensor.rank == 0 || tensor.size < _EAGER_CONST_THRESHOLD)) + return capture_eager_tensor(tensor, name); + + // Large EagerTensors and resources are captured with Placeholder ops + return _capture_helper(tensor, name, shape: shape); } if(tensor.graph != this) @@ -137,6 +139,9 @@ namespace Tensorflow.Graphs return tensor; } + Tensor capture_eager_tensor(Tensor tensor, string name) + => throw new NotImplementedException(""); + Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) { Tensor placeholder = null; @@ -190,7 +195,8 @@ namespace Tensorflow.Graphs if (dtype == TF_DataType.DtInvalid) dtype = value.dtype; - var placeholder = tf_with(ops.control_dependencies(null), ctl => array_ops.placeholder(dtype, shape: shape, name: name)); + var placeholder = tf_with(ops.control_dependencies(null), ctl + => array_ops.placeholder(dtype, shape: shape, name: name)); // custom_gradient.copy_handle_data(value, placeholder) return placeholder; } @@ -211,6 +217,13 @@ namespace Tensorflow.Graphs } } + public override Graph as_default() + { + tf.Context.graph_mode(isFunc: true); + ops.set_default_graph(this); + return this; + } + protected override void DisposeManagedResources() { base.DisposeManagedResources(); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 2c0d0944..81cb08c4 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -148,7 +148,7 @@ namespace Tensorflow /// Returns a context manager that makes this `Graph` the default graph. /// /// - public Graph as_default() + public virtual Graph as_default() { return ops.set_default_graph(this); } diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 507f260c..cd6a456e 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 2.2.0 - 0.31.2 + 0.32.0 8.0 Haiping Chen, Meinrad Recheis, Eli Belash SciSharp STACK @@ -15,7 +15,7 @@ git http://scisharpstack.org https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 - TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#, TF.NET + TensorFlow, SciSharp, Machine Learning, TensorFlow.NET, TF.NET, AI Google's TensorFlow full binding in .NET Standard. Building, training and infering deep learning models. https://tensorflownet.readthedocs.io diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 98e0066a..f1fe530a 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -293,5 +293,12 @@ namespace Tensorflow else return self; } + + public static bool is_value_dtype(this TF_DataType type) + { + return ((int)type >= 1 && (int)type <= 19) + || type == TF_DataType.TF_UINT32 + || type == TF_DataType.TF_UINT64; + } } } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index c2f4fec8..725a50b6 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -113,13 +113,13 @@ namespace Tensorflow { if (tf.executing_eagerly()) return eager_tensor; - /*else + else { var graph = get_default_graph(); if (!graph.building_function) throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); return (graph as FuncGraph).capture(eager_tensor, name: name); - }*/ + } } Tensor ret = value switch diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 5f54b37e..20bc99f6 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -83,6 +83,7 @@ namespace Tensorflow.Keras { if (_GRAPH == null) _GRAPH = new FuncGraph("keras_graph"); + return _GRAPH; } return ops.get_default_graph(); diff --git a/src/TensorFlowNET.Keras/Engine/CallContext.cs b/src/TensorFlowNET.Keras/Engine/CallContext.cs index 3768ed52..99dd7901 100644 --- a/src/TensorFlowNET.Keras/Engine/CallContext.cs +++ b/src/TensorFlowNET.Keras/Engine/CallContext.cs @@ -1,10 +1,12 @@ -namespace Tensorflow.Keras.Engine +using static Tensorflow.Binding; + +namespace Tensorflow.Keras.Engine { public class CallContext { - public CallContextManager enter() + public CallContextManager enter(bool build_graph) { - return new CallContextManager(); + return new CallContextManager(build_graph); } } } diff --git a/src/TensorFlowNET.Keras/Engine/CallContextManager.cs b/src/TensorFlowNET.Keras/Engine/CallContextManager.cs index 1d76cf02..79cb4b30 100644 --- a/src/TensorFlowNET.Keras/Engine/CallContextManager.cs +++ b/src/TensorFlowNET.Keras/Engine/CallContextManager.cs @@ -1,12 +1,20 @@ using System; +using static Tensorflow.Binding; namespace Tensorflow.Keras.Engine { public class CallContextManager : IDisposable { - public void Dispose() + bool _build_graph; + + public CallContextManager(bool build_graph) { + _build_graph = build_graph; + } + public void Dispose() + { + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs index f04e5330..63155fa3 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Apply.cs @@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine Tensors outputs = null; var eager = tf.executing_eagerly(); - using var ctxManager = CallContext.enter(); + using var ctxManager = CallContext.enter(build_graph: false); string nameScope = ""; if (eager) @@ -33,9 +33,6 @@ namespace Tensorflow.Keras.Engine else nameScope = _name_scope(); - if (!inputs.IsEagerTensor) - tf.Context.graph_mode(); - tf_with(ops.name_scope(nameScope), scope => { if (!built) @@ -48,9 +45,6 @@ namespace Tensorflow.Keras.Engine _set_mask_metadata(inputs, outputs, null); }); - if (!inputs.IsEagerTensor) - tf.Context.restore_mode(); - return outputs; } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs b/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs index 5ff2ca04..25a7cd7a 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs @@ -21,12 +21,10 @@ namespace Tensorflow.Keras.Engine base_layer_utils.create_keras_history(inputs); Tensors outputs = null; - using var ctxManager = CallContext.enter(); + using var ctxManager = CallContext.enter(build_graph: true); - // using var graph = keras.backend.get_graph(); - - if (!inputs.IsEagerTensor) - tf.Context.graph_mode(isFunc: true); + var graph = keras.backend.get_graph(); + graph.as_default(); tf_with(ops.name_scope(_name_scope()), scope => { @@ -48,8 +46,7 @@ namespace Tensorflow.Keras.Engine _set_mask_metadata(inputs, outputs, null); }); - if (!inputs.IsEagerTensor) - tf.Context.restore_mode(); + tf.Context.restore_mode(); return outputs; } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index fa072da9..8daf60a2 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -180,7 +180,7 @@ namespace Tensorflow.Keras.Engine if (inputs.IsEagerTensor || tf.Context.is_build_function()) { need_restore_mode = true; - tf.Context.eager_mode(); + tf.Context.eager_mode(isFunc: tf.Context.is_build_function()); } build(inputs); diff --git a/src/TensorFlowNET.Keras/Engine/Model.cs b/src/TensorFlowNET.Keras/Engine/Model.cs index eac7a139..3ac9cd38 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.cs @@ -4,6 +4,7 @@ using Tensorflow.Keras.Engine.DataAdapters; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Optimizers; using static Tensorflow.Binding; +using static Tensorflow.KerasApi; namespace Tensorflow.Keras.Engine { @@ -51,6 +52,7 @@ namespace Tensorflow.Keras.Engine { // Used to cache `trainable` attr of `Layer`s for `fit`. _compiled_trainable_state = _get_trainable_state(); + keras.backend._GRAPH = null; } void _init_batch_counters() diff --git a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj index e4c42061..14c5719d 100644 --- a/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj +++ b/src/TensorFlowNET.Keras/Tensorflow.Keras.csproj @@ -24,7 +24,7 @@ Keras is an API designed for human beings, not machines. Keras follows best practices for reducing cognitive load: it offers consistent & simple APIs, it minimizes the number of user actions required for common use cases, and it provides clear & actionable error messages. SciSharp STACK true - tensorflow, keras, deep learning, machine learning + tensorflow, keras, deep learning, machine learning, scisharp true Git true