From 505565550fd6501ca9e3a64f8d6589470a549f17 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 26 Dec 2020 20:42:42 -0600 Subject: [PATCH] Allow Tensors to extend. --- .../Functions/ConcreteFunction.cs | 2 +- .../Functions/TapeGradientFunctions.cs | 11 +++++++- .../Graphs/AutoGraphAttribute.cs | 6 ++--- src/TensorFlowNET.Core/Graphs/FuncGraph.cs | 27 ++++++++++++------- src/TensorFlowNET.Core/Graphs/Graph.cs | 20 +------------- src/TensorFlowNET.Core/Tensors/Tensors.cs | 22 +++++++-------- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 7 ++--- src/TensorFlowNET.Keras/Engine/Functional.cs | 4 +-- .../Engine/TensorFlowOpLayer.cs | 2 +- 9 files changed, 49 insertions(+), 52 deletions(-) diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index e9ca28d8..bc71eee4 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -139,7 +139,7 @@ namespace Tensorflow.Functions "executor_type", "", "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() }; - return tf.Runner.Execute(tf.Context, func_graph.FuncName, 1, args, attrs); + return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); } ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index 0a98d91d..89f87c62 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -1,5 +1,6 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Text; using Tensorflow.Graphs; using static Tensorflow.Binding; @@ -61,7 +62,7 @@ namespace Tensorflow.Functions processed_args.add(arg); input_index += 1; } - + tf.Logger.Debug($"Invoke backward function: {backward.Name}"); return backward.CallFlat(processed_args.ToArray(), outputs); }; @@ -91,6 +92,14 @@ namespace Tensorflow.Functions grad_ys: gradients_wrt_outputs.ToArray(), src_graph: _func_graph); + var captures_from_forward = backwards_graph.external_captures() + .Where(x => !x.IsEagerTensor && x.graph == _func_graph) + .ToArray(); + foreach(var capture in captures_from_forward) + { + _func_graph.Outputs.Add(capture); + } + var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; var backward_function_attr = new Dictionary(); backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; diff --git a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs index 7c5ab96b..3d04b8a3 100644 --- a/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs +++ b/src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs @@ -18,7 +18,7 @@ namespace Tensorflow.Graphs public override void OnEntry(MethodExecutionArgs args) { - func_name = $"autograph_{args.Instance.GetType().FullName}.{args.Method.Name}"; + func_name = $"autograph_{args.Instance.GetHashCode()}.{args.Method.Name}"; if (functions.ContainsKey(func_name)) { @@ -44,13 +44,13 @@ namespace Tensorflow.Graphs } else { - originalInputs = new Tensors(args.Arguments.Length); + originalInputs = new Tensors(); // convert args to placeholder for (var i = 0; i < args.Arguments.Length; i++) { if (args.Arguments[i] is EagerTensor tensor) { - originalInputs[i] = tensor; + originalInputs.Add(tensor); args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape, name: "inputs"); } } diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index 5eedbd8b..fa84c62b 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -16,16 +16,23 @@ namespace Tensorflow.Graphs Graph outer_graph; public Graph OuterGraph => outer_graph; - string func_name; - // _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); IntPtr func_handle; - public string FuncName => func_name; + public string FuncName => _graph_key; public Tensors Inputs { get; set; } public Tensors Outputs { get; set; } public Dictionary Attrs { get; set; } + public Dictionary _captures + = new Dictionary(); + + public Tensor[] external_captures() + => _captures.Select(x => x.Value.Item1).ToArray(); + + public Tensor[] internal_captures() + => _captures.Select(x => x.Value.Item2).ToArray(); + // new Dictionary _captures = new Dictionary(); // public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray(); @@ -35,7 +42,7 @@ namespace Tensorflow.Graphs public FuncGraph(string name) : base() { outer_graph = ops.get_default_graph(); - func_name = name; + _graph_key = name; tf.Context.graph_mode(); as_default(); @@ -44,7 +51,7 @@ namespace Tensorflow.Graphs public FuncGraph(IntPtr handle, string name, Dictionary attrs) : base() { outer_graph = ops.get_default_graph(); - func_name = name; + _graph_key = name; Attrs = attrs; // Will to test if FuncGraph has memory leak // c_api.TF_DeleteGraph(_handle); @@ -60,7 +67,7 @@ namespace Tensorflow.Graphs { using var status = new Status(); func_handle = c_api.TF_GraphToFunction(_handle, - func_name, + _graph_key, false, opers.Length, opers.Select(x => (IntPtr)x).ToArray(), @@ -82,7 +89,7 @@ namespace Tensorflow.Graphs c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, status.Handle); status.Check(true); - func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); + _graph_key = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); Inputs = inputs; // mark_as_return @@ -131,7 +138,7 @@ namespace Tensorflow.Graphs Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) { Tensor placeholder = null; - if (!_captures.Contains(tensor.Id)) + if (!_captures.ContainsKey(tensor.Id)) { placeholder = _create_substitute_placeholder(tensor, name: name, @@ -141,7 +148,7 @@ namespace Tensorflow.Graphs } else { - placeholder = (((Tensor, Tensor))_captures[tensor.Id]).Item2; + placeholder = _captures[tensor.Id].Item2; } BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => @@ -160,7 +167,7 @@ namespace Tensorflow.Graphs void add_capture(Tensor tensor, Tensor placeholder) { - _captures[tensor.Id] = (tensor, placeholder); + _captures.Add(tensor.Id, (tensor, placeholder)); if (Inputs == null) Inputs = new Tensors(placeholder); else diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 5ea2cc07..01415d65 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -87,7 +87,7 @@ namespace Tensorflow private List _unfeedable_tensors = new List(); public string _name_stack = ""; - private string _graph_key; + protected string _graph_key; public string graph_key => _graph_key; public string _last_loss_reduction; public bool _is_loss_scaled_by_optimizer { get; set; } @@ -552,23 +552,5 @@ namespace Tensorflow { return graph._handle; } - - public OrderedDictionary _captures => new OrderedDictionary(); - - public Tensor[] external_captures() - { - Tensor[] captures = new Tensor[_captures.Count]; - ICollection inner = _captures.Keys; // c[0] - inner.CopyTo(captures, 0); - return captures; - } - - public Tensor[] internal_captures() - { - Tensor[] captures = new Tensor[_captures.Count]; - ICollection inner = _captures.Values; // c[1] - inner.CopyTo(captures, 0); - return captures; - } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index 6086e142..82f51f58 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -15,7 +15,7 @@ namespace Tensorflow /// public class Tensors : IEnumerable { - Tensor[] items; + List items = new List(); public TF_DataType dtype => items.First().dtype; public TensorShape shape => items.First().TensorShape; @@ -23,7 +23,7 @@ namespace Tensorflow public Graph graph => items.First().graph; public bool IsEagerTensor => items.First().IsEagerTensor; public bool IsList { get; set; } - public int Length => items.Length; + public int Length => items.Count(); public Tensor this[int index] { @@ -40,17 +40,12 @@ namespace Tensorflow public Tensors(params Tensor[] tensors) { - items = tensors; + items.AddRange(tensors); } public Tensors(NDArray nd) { - items = new[] { ops.convert_to_tensor(nd) }; - } - - public Tensors(int count) - { - items = new Tensor[count]; + items.Add(ops.convert_to_tensor(nd)); } public IEnumerator GetEnumerator() @@ -59,6 +54,9 @@ namespace Tensorflow yield return tensor; } + public void Add(Tensor tensor) + => items.Add(tensor); + IEnumerator IEnumerable.GetEnumerator() { throw new NotImplementedException(); @@ -80,11 +78,11 @@ namespace Tensorflow => tensors.FirstOrDefault(); public static implicit operator Tensor[](Tensors tensors) - => tensors.items; + => tensors.items.ToArray(); public override string ToString() - => items.Length == 1 + => items.Count() == 1 ? items.First().ToString() - : items.Length + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); + : items.Count() + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); } } diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index e2697aeb..449a978b 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -20,6 +20,7 @@ using System.Collections.Generic; using System.Linq; using System.Text; using Tensorflow.Eager; +using Tensorflow.Graphs; using static Tensorflow.Binding; namespace Tensorflow @@ -392,14 +393,14 @@ would not be rank 1.", tensor.op.get_attr("axis"))); } else if (tensor.op.type == "Placeholder" && tensor.op.graph.building_function && - hasattr(tensor.op.graph, "internal_captures")) + tensor.op.graph is FuncGraph func_graph) { int i = 0; - foreach (Tensor capture in tensor.op.graph.internal_captures()) + foreach (Tensor capture in func_graph.internal_captures()) { if (capture.GetType() == typeof(Tensor)) { - var external_capture = tensor.op.graph.external_captures()[i]; + var external_capture = func_graph.external_captures()[i]; return constant_value_as_shape(external_capture); } diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index a02f46d2..6c67f109 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -337,10 +337,10 @@ namespace Tensorflow.Keras.Engine var layer_inputs = node.MapArguments(tensor_dict); - tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}"); + tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}"); var outputs = node.Layer.Apply(layer_inputs, is_training: training); foreach (var output in outputs.Where(x => x != null)) - tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); + tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); // Update tensor_dict for next input foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) tensor_dict[x_id] = new Queue(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); diff --git a/src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs b/src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs index 3da828de..d0bf36e6 100644 --- a/src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs +++ b/src/TensorFlowNET.Keras/Engine/TensorFlowOpLayer.cs @@ -60,7 +60,7 @@ namespace Tensorflow.Keras.Engine tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); - return op.output; + return op.outputs; } } }