From 4252952208763f3d23e9a09765c758384b796978 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Tue, 4 Apr 2023 19:19:29 +0800 Subject: [PATCH] Fix the error that loaded concrete function does not work. --- .../Functions/ConcreteFunction.cs | 36 +++++++++++++------ src/TensorFlowNET.Keras/Engine/Sequential.cs | 3 +- .../Saving/KerasObjectLoader.cs | 7 ++-- .../SaveModel/SequentialModelLoad.cs | 3 +- 4 files changed, 32 insertions(+), 17 deletions(-) diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index 3cc27f25..69a31ba0 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -6,6 +6,7 @@ using Tensorflow.Eager; using Tensorflow.Framework.Models; using Tensorflow.Graphs; using Tensorflow.Train; +using Tensorflow.Util; using static Tensorflow.Binding; namespace Tensorflow.Functions @@ -21,6 +22,7 @@ namespace Tensorflow.Functions protected Dictionary _attrs; protected FunctionSpec _function_spec; protected FunctionSpec _pre_initialized_function_spec = null; + protected EagerDefinedFunction _inference_function; internal ForwardBackwardCall forward_backward; public Tensor[] Inputs => func_graph.Inputs; public Tensor[] CapturedInputs => func_graph.external_captures; @@ -39,6 +41,7 @@ namespace Tensorflow.Functions _captured_inputs = func_graph.external_captures; _attrs= new Dictionary(); _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); + _inference_function = _delayed_rewrite_functions.Forward(); } public ConcreteFunction(FuncGraph graph, Dictionary attrs = null) @@ -49,6 +52,7 @@ namespace Tensorflow.Functions //ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); _attrs = attrs; _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); + _inference_function = _delayed_rewrite_functions.Forward(); } public ConcreteFunction(Func func, TF_DataType 
dtype) @@ -69,6 +73,7 @@ namespace Tensorflow.Functions _captured_inputs = func_graph.external_captures; _attrs = new Dictionary(); _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); + _inference_function = _delayed_rewrite_functions.Forward(); } public ConcreteFunction(Func func, TF_DataType dtype) @@ -92,6 +97,7 @@ namespace Tensorflow.Functions _captured_inputs = func_graph.external_captures; _attrs = new Dictionary(); _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); + _inference_function = _delayed_rewrite_functions.Forward(); } /*public ConcreteFunction(Func func, @@ -154,9 +160,10 @@ namespace Tensorflow.Functions { tensor_inputs.Add(arg); // If we're graph building, shape inference is on. - if (!executing_eagerly) - { - } + } + if (!executing_eagerly) + { + } tensor_inputs.AddRange(captured_inputs); @@ -166,12 +173,13 @@ namespace Tensorflow.Functions // No tape is watching; skip to running the function. if (possible_gradient_type == 0 && executing_eagerly) { - var attrs = new object[] - { - "executor_type", "", - "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() - }; - return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); + return _build_call_outputs(_inference_function.Call(args)); + //var attrs = new object[] + //{ + // "executor_type", "", + // "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() + //}; + //return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); } if (forward_backward == null) @@ -184,10 +192,11 @@ namespace Tensorflow.Functions } else { + // TODO(Rinne): add `default_graph._override_gradient_function`. flat_outputs = forward_function.Call(args_with_tangents); } forward_backward.Record(flat_outputs); - return flat_outputs; + return _build_call_outputs(flat_outputs); } public void AddTograph(Graph? 
g = null) @@ -262,6 +271,13 @@ namespace Tensorflow.Functions }; } + private Tensors _build_call_outputs(Tensors result) + { + // TODO(Rinne): deal with `func_graph.structured_outputs` + + return result; + } + public override string ToString() => Name; } diff --git a/src/TensorFlowNET.Keras/Engine/Sequential.cs b/src/TensorFlowNET.Keras/Engine/Sequential.cs index 69665388..c9b8cfac 100644 --- a/src/TensorFlowNET.Keras/Engine/Sequential.cs +++ b/src/TensorFlowNET.Keras/Engine/Sequential.cs @@ -124,11 +124,12 @@ namespace Tensorflow.Keras.Engine if (set_inputs || _is_graph_network) { _init_graph_network(inputs, outputs); - _is_graph_network = true; + _graph_initialized = true; } else { _self_tracked_trackables.add(layer); + // TODO(Rinne): self._handle_deferred_layer_dependencies([layer]) } } diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index 3b5d3274..29c29405 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -189,11 +189,8 @@ namespace Tensorflow.Keras.Saving } // `model.__init__(layers, config["name"])`InitLayers(layers); - s = new Sequential(new SequentialArgs(){ - Layers = layers.Select(x => x as ILayer).ToList(), - Name = config["name"].ToObject() - }); - //s.Name = config["name"].ToObject(); + s.InitLayers(layers.Select(x => x as ILayer)); + s.Name = config["name"].ToObject(); if(s.inputs is null || s.inputs.Length == 0) { var first_layer = _get_child_layer_node_ids(model_id)[0]; diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index eeb5f9e4..17d864d2 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -62,10 +62,11 @@ public class SequentialModelLoad [TestMethod] public void Temp() { - var model = 
tf.keras.models.load_model(@"C:\Work\tf.net\tf_test\python_func"); + var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func"); model.summary(); var x = tf.ones((2, 10)); var y = model.Apply(x); + Console.WriteLine(y); } }