| @@ -139,7 +139,7 @@ namespace Tensorflow.Functions | |||||
| "executor_type", "", | "executor_type", "", | ||||
| "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | ||||
| }; | }; | ||||
| return tf.Runner.Execute(tf.Context, func_graph.FuncName, 1, args, attrs); | |||||
| return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | |||||
| } | } | ||||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -61,7 +62,7 @@ namespace Tensorflow.Functions | |||||
| processed_args.add(arg); | processed_args.add(arg); | ||||
| input_index += 1; | input_index += 1; | ||||
| } | } | ||||
| tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||||
| return backward.CallFlat(processed_args.ToArray(), outputs); | return backward.CallFlat(processed_args.ToArray(), outputs); | ||||
| }; | }; | ||||
| @@ -91,6 +92,14 @@ namespace Tensorflow.Functions | |||||
| grad_ys: gradients_wrt_outputs.ToArray(), | grad_ys: gradients_wrt_outputs.ToArray(), | ||||
| src_graph: _func_graph); | src_graph: _func_graph); | ||||
| var captures_from_forward = backwards_graph.external_captures() | |||||
| .Where(x => !x.IsEagerTensor && x.graph == _func_graph) | |||||
| .ToArray(); | |||||
| foreach(var capture in captures_from_forward) | |||||
| { | |||||
| _func_graph.Outputs.Add(capture); | |||||
| } | |||||
| var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | ||||
| var backward_function_attr = new Dictionary<string, string>(); | var backward_function_attr = new Dictionary<string, string>(); | ||||
| backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | ||||
| @@ -18,7 +18,7 @@ namespace Tensorflow.Graphs | |||||
| public override void OnEntry(MethodExecutionArgs args) | public override void OnEntry(MethodExecutionArgs args) | ||||
| { | { | ||||
| func_name = $"autograph_{args.Instance.GetType().FullName}.{args.Method.Name}"; | |||||
| func_name = $"autograph_{args.Instance.GetHashCode()}.{args.Method.Name}"; | |||||
| if (functions.ContainsKey(func_name)) | if (functions.ContainsKey(func_name)) | ||||
| { | { | ||||
| @@ -44,13 +44,13 @@ namespace Tensorflow.Graphs | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| originalInputs = new Tensors(args.Arguments.Length); | |||||
| originalInputs = new Tensors(); | |||||
| // convert args to placeholder | // convert args to placeholder | ||||
| for (var i = 0; i < args.Arguments.Length; i++) | for (var i = 0; i < args.Arguments.Length; i++) | ||||
| { | { | ||||
| if (args.Arguments[i] is EagerTensor tensor) | if (args.Arguments[i] is EagerTensor tensor) | ||||
| { | { | ||||
| originalInputs[i] = tensor; | |||||
| originalInputs.Add(tensor); | |||||
| args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape, name: "inputs"); | args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape, name: "inputs"); | ||||
| } | } | ||||
| } | } | ||||
| @@ -16,16 +16,23 @@ namespace Tensorflow.Graphs | |||||
| Graph outer_graph; | Graph outer_graph; | ||||
| public Graph OuterGraph => outer_graph; | public Graph OuterGraph => outer_graph; | ||||
| string func_name; | |||||
| // _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | // _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | ||||
| IntPtr func_handle; | IntPtr func_handle; | ||||
| public string FuncName => func_name; | |||||
| public string FuncName => _graph_key; | |||||
| public Tensors Inputs { get; set; } | public Tensors Inputs { get; set; } | ||||
| public Tensors Outputs { get; set; } | public Tensors Outputs { get; set; } | ||||
| public Dictionary<string, string> Attrs { get; set; } | public Dictionary<string, string> Attrs { get; set; } | ||||
| public Dictionary<long, (Tensor, Tensor)> _captures | |||||
| = new Dictionary<long, (Tensor, Tensor)>(); | |||||
| /* Returns Item1 of every entry in _captures (the captured tensor that add_capture stores alongside its placeholder), in the dictionary's enumeration order. A new array is built on each call. */ public Tensor[] external_captures() | |||||
| => _captures.Select(x => x.Value.Item1).ToArray(); | |||||
| /* Returns Item2 of every entry in _captures (the placeholder that add_capture stores for each captured tensor), in the dictionary's enumeration order. A new array is built on each call. */ public Tensor[] internal_captures() | |||||
| => _captures.Select(x => x.Value.Item2).ToArray(); | |||||
| // new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>(); | // new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>(); | ||||
| // public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray(); | // public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray(); | ||||
| @@ -35,7 +42,7 @@ namespace Tensorflow.Graphs | |||||
| public FuncGraph(string name) : base() | public FuncGraph(string name) : base() | ||||
| { | { | ||||
| outer_graph = ops.get_default_graph(); | outer_graph = ops.get_default_graph(); | ||||
| func_name = name; | |||||
| _graph_key = name; | |||||
| tf.Context.graph_mode(); | tf.Context.graph_mode(); | ||||
| as_default(); | as_default(); | ||||
| @@ -44,7 +51,7 @@ namespace Tensorflow.Graphs | |||||
| public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() | public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() | ||||
| { | { | ||||
| outer_graph = ops.get_default_graph(); | outer_graph = ops.get_default_graph(); | ||||
| func_name = name; | |||||
| _graph_key = name; | |||||
| Attrs = attrs; | Attrs = attrs; | ||||
| // TODO: test whether FuncGraph has a memory leak | // TODO: test whether FuncGraph has a memory leak | ||||
| // c_api.TF_DeleteGraph(_handle); | // c_api.TF_DeleteGraph(_handle); | ||||
| @@ -60,7 +67,7 @@ namespace Tensorflow.Graphs | |||||
| { | { | ||||
| using var status = new Status(); | using var status = new Status(); | ||||
| func_handle = c_api.TF_GraphToFunction(_handle, | func_handle = c_api.TF_GraphToFunction(_handle, | ||||
| func_name, | |||||
| _graph_key, | |||||
| false, | false, | ||||
| opers.Length, | opers.Length, | ||||
| opers.Select(x => (IntPtr)x).ToArray(), | opers.Select(x => (IntPtr)x).ToArray(), | ||||
| @@ -82,7 +89,7 @@ namespace Tensorflow.Graphs | |||||
| c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, status.Handle); | c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, status.Handle); | ||||
| status.Check(true); | status.Check(true); | ||||
| func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); | |||||
| _graph_key = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); | |||||
| Inputs = inputs; | Inputs = inputs; | ||||
| // mark_as_return | // mark_as_return | ||||
| @@ -131,7 +138,7 @@ namespace Tensorflow.Graphs | |||||
| Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) | Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) | ||||
| { | { | ||||
| Tensor placeholder = null; | Tensor placeholder = null; | ||||
| if (!_captures.Contains(tensor.Id)) | |||||
| if (!_captures.ContainsKey(tensor.Id)) | |||||
| { | { | ||||
| placeholder = _create_substitute_placeholder(tensor, | placeholder = _create_substitute_placeholder(tensor, | ||||
| name: name, | name: name, | ||||
| @@ -141,7 +148,7 @@ namespace Tensorflow.Graphs | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| placeholder = (((Tensor, Tensor))_captures[tensor.Id]).Item2; | |||||
| placeholder = _captures[tensor.Id].Item2; | |||||
| } | } | ||||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | ||||
| @@ -160,7 +167,7 @@ namespace Tensorflow.Graphs | |||||
| void add_capture(Tensor tensor, Tensor placeholder) | void add_capture(Tensor tensor, Tensor placeholder) | ||||
| { | { | ||||
| _captures[tensor.Id] = (tensor, placeholder); | |||||
| _captures.Add(tensor.Id, (tensor, placeholder)); | |||||
| if (Inputs == null) | if (Inputs == null) | ||||
| Inputs = new Tensors(placeholder); | Inputs = new Tensors(placeholder); | ||||
| else | else | ||||
| @@ -87,7 +87,7 @@ namespace Tensorflow | |||||
| private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | ||||
| public string _name_stack = ""; | public string _name_stack = ""; | ||||
| private string _graph_key; | |||||
| protected string _graph_key; | |||||
| public string graph_key => _graph_key; | public string graph_key => _graph_key; | ||||
| public string _last_loss_reduction; | public string _last_loss_reduction; | ||||
| public bool _is_loss_scaled_by_optimizer { get; set; } | public bool _is_loss_scaled_by_optimizer { get; set; } | ||||
| @@ -552,23 +552,5 @@ namespace Tensorflow | |||||
| { | { | ||||
| return graph._handle; | return graph._handle; | ||||
| } | } | ||||
| public OrderedDictionary _captures => new OrderedDictionary(); | |||||
| public Tensor[] external_captures() | |||||
| { | |||||
| Tensor[] captures = new Tensor[_captures.Count]; | |||||
| ICollection inner = _captures.Keys; // c[0] | |||||
| inner.CopyTo(captures, 0); | |||||
| return captures; | |||||
| } | |||||
| public Tensor[] internal_captures() | |||||
| { | |||||
| Tensor[] captures = new Tensor[_captures.Count]; | |||||
| ICollection inner = _captures.Values; // c[1] | |||||
| inner.CopyTo(captures, 0); | |||||
| return captures; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -15,7 +15,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class Tensors : IEnumerable<Tensor> | public class Tensors : IEnumerable<Tensor> | ||||
| { | { | ||||
| Tensor[] items; | |||||
| List<Tensor> items = new List<Tensor>(); | |||||
| public TF_DataType dtype => items.First().dtype; | public TF_DataType dtype => items.First().dtype; | ||||
| public TensorShape shape => items.First().TensorShape; | public TensorShape shape => items.First().TensorShape; | ||||
| @@ -23,7 +23,7 @@ namespace Tensorflow | |||||
| public Graph graph => items.First().graph; | public Graph graph => items.First().graph; | ||||
| public bool IsEagerTensor => items.First().IsEagerTensor; | public bool IsEagerTensor => items.First().IsEagerTensor; | ||||
| public bool IsList { get; set; } | public bool IsList { get; set; } | ||||
| public int Length => items.Length; | |||||
| /* Number of tensors currently held in items. Reads List<T>.Count directly instead of going through the LINQ Count() extension. */ public int Length => items.Count; | |||||
| public Tensor this[int index] | public Tensor this[int index] | ||||
| { | { | ||||
| @@ -40,17 +40,12 @@ namespace Tensorflow | |||||
| public Tensors(params Tensor[] tensors) | public Tensors(params Tensor[] tensors) | ||||
| { | { | ||||
| items = tensors; | |||||
| items.AddRange(tensors); | |||||
| } | } | ||||
| public Tensors(NDArray nd) | public Tensors(NDArray nd) | ||||
| { | { | ||||
| items = new[] { ops.convert_to_tensor(nd) }; | |||||
| } | |||||
| public Tensors(int count) | |||||
| { | |||||
| items = new Tensor[count]; | |||||
| items.Add(ops.convert_to_tensor(nd)); | |||||
| } | } | ||||
| public IEnumerator<Tensor> GetEnumerator() | public IEnumerator<Tensor> GetEnumerator() | ||||
| @@ -59,6 +54,9 @@ namespace Tensorflow | |||||
| yield return tensor; | yield return tensor; | ||||
| } | } | ||||
| /* Appends the given tensor to the end of the underlying items list. */ public void Add(Tensor tensor) | |||||
| => items.Add(tensor); | |||||
| IEnumerator IEnumerable.GetEnumerator() | IEnumerator IEnumerable.GetEnumerator() | ||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| @@ -80,11 +78,11 @@ namespace Tensorflow | |||||
| => tensors.FirstOrDefault(); | => tensors.FirstOrDefault(); | ||||
| public static implicit operator Tensor[](Tensors tensors) | public static implicit operator Tensor[](Tensors tensors) | ||||
| => tensors.items; | |||||
| => tensors.items.ToArray(); | |||||
| public override string ToString() | public override string ToString() | ||||
| => items.Length == 1 | |||||
| => items.Count() == 1 | |||||
| ? items.First().ToString() | ? items.First().ToString() | ||||
| : items.Length + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); | |||||
| : items.Count() + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); | |||||
| } | } | ||||
| } | } | ||||
| @@ -20,6 +20,7 @@ using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Graphs; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -392,14 +393,14 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||||
| } | } | ||||
| else if (tensor.op.type == "Placeholder" && | else if (tensor.op.type == "Placeholder" && | ||||
| tensor.op.graph.building_function && | tensor.op.graph.building_function && | ||||
| hasattr(tensor.op.graph, "internal_captures")) | |||||
| tensor.op.graph is FuncGraph func_graph) | |||||
| { | { | ||||
| int i = 0; | int i = 0; | ||||
| foreach (Tensor capture in tensor.op.graph.internal_captures()) | |||||
| foreach (Tensor capture in func_graph.internal_captures()) | |||||
| { | { | ||||
| if (capture.GetType() == typeof(Tensor)) | if (capture.GetType() == typeof(Tensor)) | ||||
| { | { | ||||
| var external_capture = tensor.op.graph.external_captures()[i]; | |||||
| var external_capture = func_graph.external_captures()[i]; | |||||
| return constant_value_as_shape(external_capture); | return constant_value_as_shape(external_capture); | ||||
| } | } | ||||
| @@ -337,10 +337,10 @@ namespace Tensorflow.Keras.Engine | |||||
| var layer_inputs = node.MapArguments(tensor_dict); | var layer_inputs = node.MapArguments(tensor_dict); | ||||
| tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}"); | |||||
| tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}"); | |||||
| var outputs = node.Layer.Apply(layer_inputs, is_training: training); | var outputs = node.Layer.Apply(layer_inputs, is_training: training); | ||||
| foreach (var output in outputs.Where(x => x != null)) | foreach (var output in outputs.Where(x => x != null)) | ||||
| tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); | |||||
| tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); | |||||
| // Update tensor_dict for next input | // Update tensor_dict for next input | ||||
| foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | ||||
| tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | ||||
| @@ -60,7 +60,7 @@ namespace Tensorflow.Keras.Engine | |||||
| tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); | tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); | ||||
| return op.output; | |||||
| return op.outputs; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||