| @@ -13,6 +13,7 @@ namespace Tensorflow.Functions | |||||
| public class ConcreteFunction | public class ConcreteFunction | ||||
| { | { | ||||
| FuncGraph func_graph; | FuncGraph func_graph; | ||||
| ForwardBackwardCall forward_backward; | |||||
| public Tensor[] Inputs => func_graph.Inputs; | public Tensor[] Inputs => func_graph.Inputs; | ||||
| public Tensor[] CapturedInputs => func_graph.external_captures; | public Tensor[] CapturedInputs => func_graph.external_captures; | ||||
| @@ -151,7 +152,8 @@ namespace Tensorflow.Functions | |||||
| return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | ||||
| } | } | ||||
| var forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||||
| if (forward_backward == null) | |||||
| forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly); | |||||
| var (forward_function, args_with_tangents) = forward_backward.Forward(); | var (forward_function, args_with_tangents) = forward_backward.Forward(); | ||||
| Tensors flat_outputs = null; | Tensors flat_outputs = null; | ||||
| if (executing_eagerly) | if (executing_eagerly) | ||||
| @@ -13,6 +13,7 @@ namespace Tensorflow.Functions | |||||
| Tensors _inference_args; | Tensors _inference_args; | ||||
| Tensors _input_tangents; | Tensors _input_tangents; | ||||
| bool _tape_watching; | bool _tape_watching; | ||||
| EagerDefinedFunction forward_function; | |||||
| public ForwardBackwardCall(TapeGradientFunctions functions, | public ForwardBackwardCall(TapeGradientFunctions functions, | ||||
| Tensors inference_args, | Tensors inference_args, | ||||
| @@ -22,10 +23,11 @@ namespace Tensorflow.Functions | |||||
| _inference_args = inference_args; | _inference_args = inference_args; | ||||
| _tape_watching = tape_watching; | _tape_watching = tape_watching; | ||||
| } | } | ||||
| public (EagerDefinedFunction, Tensors) Forward() | public (EagerDefinedFunction, Tensors) Forward() | ||||
| { | { | ||||
| var forward_function = _functions.Forward(_inference_args); | |||||
| if (forward_function == null) | |||||
| forward_function = _functions.Forward(_inference_args); | |||||
| return (forward_function, _inference_args); | return (forward_function, _inference_args); | ||||
| } | } | ||||
| @@ -25,6 +25,7 @@ namespace Tensorflow.Functions | |||||
| protected List<int> _forwardprop_output_indices; | protected List<int> _forwardprop_output_indices; | ||||
| protected int _num_forwardprop_outputs; | protected int _num_forwardprop_outputs; | ||||
| protected ConcreteFunction _backward; | protected ConcreteFunction _backward; | ||||
| BackwardFunction _backward_function_wrapper; | |||||
| public TapeGradientFunctions(FuncGraph func_graph, | public TapeGradientFunctions(FuncGraph func_graph, | ||||
| bool need_gradients_for_jvps) | bool need_gradients_for_jvps) | ||||
| @@ -58,60 +59,66 @@ namespace Tensorflow.Functions | |||||
| /// <returns></returns> | /// <returns></returns> | ||||
| (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | (BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs) | ||||
| { | { | ||||
| var capture_mapping = new Dictionary<long, Tensor>(); | |||||
| foreach(var (i, output) in enumerate(outputs)) | |||||
| capture_mapping[forward_graph.Outputs[i].Id] = output; | |||||
| var remapped_captures = new Tensors(); | |||||
| foreach(var capture in backward.CapturedInputs) | |||||
| { | |||||
| if (capture_mapping.ContainsKey(capture.Id)) | |||||
| remapped_captures.Add(capture_mapping[capture.Id]); | |||||
| } | |||||
| var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; | var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length; | ||||
| var recorded_outputs = new Tensors(); | var recorded_outputs = new Tensors(); | ||||
| var relevant_outputs = outputs; | |||||
| var trainable_recorded_outputs = 0; | var trainable_recorded_outputs = 0; | ||||
| var skip_positions = new List<int>(); | |||||
| foreach (var (output_index, output) in enumerate(relevant_outputs)) | |||||
| foreach (var (output_index, output) in enumerate(outputs)) | |||||
| { | { | ||||
| if (trainable_recorded_outputs < backward_function_inputs) | if (trainable_recorded_outputs < backward_function_inputs) | ||||
| recorded_outputs.Add(output); | recorded_outputs.Add(output); | ||||
| if (gradients_util.IsTrainable(output)) | if (gradients_util.IsTrainable(output)) | ||||
| trainable_recorded_outputs += 1; | trainable_recorded_outputs += 1; | ||||
| else | |||||
| skip_positions.Add(output_index); | |||||
| } | } | ||||
| BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) => | |||||
| if(_backward_function_wrapper == null) | |||||
| { | { | ||||
| var processed_args = new Tensors(); | |||||
| var input_index = 0; | |||||
| foreach (var (output_index, arg) in enumerate(args)) | |||||
| var capture_mapping = new Dictionary<long, Tensor>(); | |||||
| foreach (var (i, output) in enumerate(outputs)) | |||||
| capture_mapping[forward_graph.Outputs[i].Id] = output; | |||||
| var remapped_captures = new Tensors(); | |||||
| foreach (var capture in backward.CapturedInputs) | |||||
| { | { | ||||
| if (skip_positions.Contains(output_index)) | |||||
| continue; | |||||
| if (arg == null) | |||||
| throw new NotImplementedException(""); | |||||
| processed_args.Add(arg); | |||||
| input_index += 1; | |||||
| if (input_index >= backward_function_inputs) | |||||
| break; | |||||
| if (capture_mapping.ContainsKey(capture.Id)) | |||||
| remapped_captures.Add(capture_mapping[capture.Id]); | |||||
| } | } | ||||
| tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||||
| var gradients = backward.CallFlat(processed_args, remapped_captures); | |||||
| foreach (var unneeded_gradient_index in unneeded_gradients) | |||||
| var skip_positions = new List<int>(); | |||||
| foreach (var (output_index, output) in enumerate(outputs)) | |||||
| { | { | ||||
| var index = Convert.ToInt32(unneeded_gradient_index); | |||||
| if (gradients.Length <= index) | |||||
| gradients.Insert(index, null); | |||||
| if (!gradients_util.IsTrainable(output)) | |||||
| skip_positions.Add(output_index); | |||||
| } | } | ||||
| return gradients; | |||||
| }; | |||||
| _backward_function_wrapper = (args, unneeded_gradients) => | |||||
| { | |||||
| var processed_args = new Tensors(); | |||||
| var input_index = 0; | |||||
| foreach (var (output_index, arg) in enumerate(args)) | |||||
| { | |||||
| if (skip_positions.Contains(output_index)) | |||||
| continue; | |||||
| if (arg == null) | |||||
| throw new NotImplementedException(""); | |||||
| processed_args.Add(arg); | |||||
| input_index += 1; | |||||
| if (input_index >= backward_function_inputs) | |||||
| break; | |||||
| } | |||||
| tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||||
| var gradients = backward.CallFlat(processed_args, remapped_captures); | |||||
| foreach (var unneeded_gradient_index in unneeded_gradients) | |||||
| { | |||||
| var index = Convert.ToInt32(unneeded_gradient_index); | |||||
| if (gradients.Length <= index) | |||||
| gradients.Insert(index, null); | |||||
| } | |||||
| return gradients; | |||||
| }; | |||||
| } | |||||
| return (_backward_function_wrapper, recorded_outputs); | return (_backward_function_wrapper, recorded_outputs); | ||||
| } | } | ||||
| @@ -376,6 +376,10 @@ namespace Tensorflow | |||||
| public static int uid_function() | public static int uid_function() | ||||
| => Interlocked.Increment(ref uid_number_for_function); | => Interlocked.Increment(ref uid_number_for_function); | ||||
| static int uid_number_for_layer = 0; | |||||
| public static int uid_layer() | |||||
| => Interlocked.Increment(ref uid_number_for_layer); | |||||
| public static void reset_uid() | public static void reset_uid() | ||||
| { | { | ||||
| uid_number = -1; | uid_number = -1; | ||||
| @@ -66,6 +66,8 @@ namespace Tensorflow.Keras.Engine | |||||
| protected List<IVariableV1> non_trainable_weights; | protected List<IVariableV1> non_trainable_weights; | ||||
| public List<IVariableV1> non_trainable_variables => non_trainable_weights; | public List<IVariableV1> non_trainable_variables => non_trainable_weights; | ||||
| protected int id; | |||||
| public int Id => id; | |||||
| protected string name; | protected string name; | ||||
| protected string base_name; | protected string base_name; | ||||
| public string Name => name; | public string Name => name; | ||||
| @@ -96,6 +98,7 @@ namespace Tensorflow.Keras.Engine | |||||
| built = false; | built = false; | ||||
| SupportsMasking = false; | SupportsMasking = false; | ||||
| id = ops.uid_layer(); | |||||
| _init_set_name(args.Name); | _init_set_name(args.Name); | ||||
| trainable_weights = new List<IVariableV1>(); | trainable_weights = new List<IVariableV1>(); | ||||
| non_trainable_weights = new List<IVariableV1>(); | non_trainable_weights = new List<IVariableV1>(); | ||||
| @@ -8,6 +8,7 @@ using Tensorflow.Graphs; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Tensorflow.Functions; | |||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| { | { | ||||
| @@ -35,10 +36,39 @@ namespace Tensorflow.Keras.Layers | |||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
| { | { | ||||
| if (tf.Context.executing_eagerly()) | if (tf.Context.executing_eagerly()) | ||||
| return _defun_call(inputs); | |||||
| return DeFunCall(inputs); | |||||
| return MakOp(inputs); | return MakOp(inputs); | ||||
| } | } | ||||
| ConcreteFunction function; | |||||
| Tensors DeFunCall(Tensors inputs) | |||||
| { | |||||
| if(function == null) | |||||
| { | |||||
| function = new ConcreteFunction(name); | |||||
| function.Enter(); | |||||
| int i = 0; | |||||
| var graph_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: $"defun_inputs_{i++}")).ToArray(); | |||||
| var graph_outputs = MakOp(graph_inputs); | |||||
| graph_outputs = mark_as_return(graph_outputs); | |||||
| function.ToGraph(graph_inputs, graph_outputs); | |||||
| function.Exit(); | |||||
| } | |||||
| var outputs = function.FilteredCall(inputs); | |||||
| return outputs; | |||||
| } | |||||
| Tensors mark_as_return(Tensors tensors) | |||||
| { | |||||
| var result = new Tensors(); | |||||
| foreach (var tensor in tensors) | |||||
| result.Add(array_ops.identity(tensor)); | |||||
| return result; | |||||
| } | |||||
| [AutoGraph] | [AutoGraph] | ||||
| Tensors _defun_call(Tensors inputs) | Tensors _defun_call(Tensors inputs) | ||||
| => MakOp(inputs); | => MakOp(inputs); | ||||