diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index affc0b61..c52d0b5f 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -13,6 +13,7 @@ namespace Tensorflow.Functions public class ConcreteFunction { FuncGraph func_graph; + ForwardBackwardCall forward_backward; public Tensor[] Inputs => func_graph.Inputs; public Tensor[] CapturedInputs => func_graph.external_captures; @@ -151,7 +152,8 @@ namespace Tensorflow.Functions return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); } - var forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); + if (forward_backward == null) + forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); var (forward_function, args_with_tangents) = forward_backward.Forward(); Tensors flat_outputs = null; if (executing_eagerly) diff --git a/src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs b/src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs index cb4d6f1c..392c0695 100644 --- a/src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs +++ b/src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs @@ -13,6 +13,7 @@ namespace Tensorflow.Functions Tensors _inference_args; Tensors _input_tangents; bool _tape_watching; + EagerDefinedFunction forward_function; public ForwardBackwardCall(TapeGradientFunctions functions, Tensors inference_args, @@ -22,10 +23,11 @@ namespace Tensorflow.Functions _inference_args = inference_args; _tape_watching = tape_watching; } - + public (EagerDefinedFunction, Tensors) Forward() { - var forward_function = _functions.Forward(_inference_args); + if (forward_function == null) + forward_function = _functions.Forward(_inference_args); return (forward_function, _inference_args); } diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index c803b2b3..33f3d692 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -25,6 +25,7 @@ namespace Tensorflow.Functions protected List _forwardprop_output_indices; protected int _num_forwardprop_outputs; protected ConcreteFunction _backward; + BackwardFunction _backward_function_wrapper; public TapeGradientFunctions(FuncGraph func_graph, bool need_gradients_for_jvps) @@ -58,60 +59,66 @@ namespace Tensorflow.Functions /// (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) { - var capture_mapping = new Dictionary(); - foreach(var (i, output) in enumerate(outputs)) - capture_mapping[forward_graph.Outputs[i].Id] = output; - - var remapped_captures = new Tensors(); - foreach(var capture in backward.CapturedInputs) - { - if (capture_mapping.ContainsKey(capture.Id)) - remapped_captures.Add(capture_mapping[capture.Id]); - } - var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; var recorded_outputs = new Tensors(); - var relevant_outputs = outputs; var trainable_recorded_outputs = 0; - var skip_positions = new List(); - foreach (var (output_index, output) in enumerate(relevant_outputs)) + foreach (var (output_index, output) in enumerate(outputs)) { if (trainable_recorded_outputs < backward_function_inputs) recorded_outputs.Add(output); if (gradients_util.IsTrainable(output)) trainable_recorded_outputs += 1; - else - skip_positions.Add(output_index); } - BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) => + if(_backward_function_wrapper == null) { - var processed_args = new Tensors(); - var input_index = 0; - foreach (var (output_index, arg) in enumerate(args)) + var capture_mapping = new Dictionary(); + foreach (var (i, output) in enumerate(outputs)) + capture_mapping[forward_graph.Outputs[i].Id] = output; + + var remapped_captures = new Tensors(); + foreach (var capture in backward.CapturedInputs) { - if (skip_positions.Contains(output_index)) - continue; - if (arg == null) - throw new NotImplementedException(""); - processed_args.Add(arg); - input_index += 1; - if (input_index >= backward_function_inputs) - break; + if (capture_mapping.ContainsKey(capture.Id)) + remapped_captures.Add(capture_mapping[capture.Id]); } - tf.Logger.Debug($"Invoke backward function: {backward.Name}"); - var gradients = backward.CallFlat(processed_args, remapped_captures); - - foreach (var unneeded_gradient_index in unneeded_gradients) + var skip_positions = new List(); + foreach (var (output_index, output) in enumerate(outputs)) { - var index = Convert.ToInt32(unneeded_gradient_index); - if (gradients.Length <= index) - gradients.Insert(index, null); + if (!gradients_util.IsTrainable(output)) + skip_positions.Add(output_index); } - return gradients; - }; + _backward_function_wrapper = (args, unneeded_gradients) => + { + var processed_args = new Tensors(); + var input_index = 0; + foreach (var (output_index, arg) in enumerate(args)) + { + if (skip_positions.Contains(output_index)) + continue; + if (arg == null) + throw new NotImplementedException(""); + processed_args.Add(arg); + input_index += 1; + if (input_index >= backward_function_inputs) + break; + } + + tf.Logger.Debug($"Invoke backward function: {backward.Name}"); + var gradients = backward.CallFlat(processed_args, remapped_captures); + + foreach (var unneeded_gradient_index in unneeded_gradients) + { + var index = Convert.ToInt32(unneeded_gradient_index); + if (gradients.Length <= index) + gradients.Insert(index, null); + } + + return gradients; + }; + } return (_backward_function_wrapper, recorded_outputs); } diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 20b4f124..5bb7cb3a 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -376,6 +376,10 @@ namespace Tensorflow public static int uid_function() => Interlocked.Increment(ref uid_number_for_function); + static int uid_number_for_layer = 0; + public static int uid_layer() + => Interlocked.Increment(ref uid_number_for_layer); + public static void reset_uid() { uid_number = -1; diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 50d0fbe9..33894136 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -66,6 +66,8 @@ namespace Tensorflow.Keras.Engine protected List non_trainable_weights; public List non_trainable_variables => non_trainable_weights; + protected int id; + public int Id => id; protected string name; protected string base_name; public string Name => name; @@ -96,6 +98,7 @@ namespace Tensorflow.Keras.Engine built = false; SupportsMasking = false; + id = ops.uid_layer(); _init_set_name(args.Name); trainable_weights = new List(); non_trainable_weights = new List(); diff --git a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs index 288911c7..024a8fc5 100644 --- a/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs @@ -8,6 +8,7 @@ using Tensorflow.Graphs; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using static Tensorflow.Binding; +using Tensorflow.Functions; namespace Tensorflow.Keras.Layers { @@ -35,10 +36,39 @@ namespace Tensorflow.Keras.Layers protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { if (tf.Context.executing_eagerly()) - return _defun_call(inputs); + return DeFunCall(inputs); return MakOp(inputs); } + ConcreteFunction function; + Tensors DeFunCall(Tensors inputs) + { + if(function == null) + { + function = new ConcreteFunction(name); + function.Enter(); + + int i = 0; + var graph_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: $"defun_inputs_{i++}")).ToArray(); + var graph_outputs = MakOp(graph_inputs); + graph_outputs = mark_as_return(graph_outputs); + + function.ToGraph(graph_inputs, graph_outputs); + function.Exit(); + } + + var outputs = function.FilteredCall(inputs); + return outputs; + } + + Tensors mark_as_return(Tensors tensors) + { + var result = new Tensors(); + foreach (var tensor in tensors) + result.Add(array_ops.identity(tensor)); + return result; + } + [AutoGraph] Tensors _defun_call(Tensors inputs) => MakOp(inputs);