| @@ -139,7 +139,7 @@ namespace Tensorflow.Functions | |||
| "executor_type", "", | |||
| "config_proto", tf.Context.FunctionCallOptions.config_proto_serialized() | |||
| }; | |||
| return tf.Runner.Execute(tf.Context, func_graph.FuncName, 1, args, attrs); | |||
| return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs); | |||
| } | |||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Graphs; | |||
| using static Tensorflow.Binding; | |||
| @@ -61,7 +62,7 @@ namespace Tensorflow.Functions | |||
| processed_args.add(arg); | |||
| input_index += 1; | |||
| } | |||
| tf.Logger.Debug($"Invoke backward function: {backward.Name}"); | |||
| return backward.CallFlat(processed_args.ToArray(), outputs); | |||
| }; | |||
| @@ -91,6 +92,14 @@ namespace Tensorflow.Functions | |||
| grad_ys: gradients_wrt_outputs.ToArray(), | |||
| src_graph: _func_graph); | |||
| var captures_from_forward = backwards_graph.external_captures() | |||
| .Where(x => !x.IsEagerTensor && x.graph == _func_graph) | |||
| .ToArray(); | |||
| foreach(var capture in captures_from_forward) | |||
| { | |||
| _func_graph.Outputs.Add(capture); | |||
| } | |||
| var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | |||
| var backward_function_attr = new Dictionary<string, string>(); | |||
| backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
| @@ -18,7 +18,7 @@ namespace Tensorflow.Graphs | |||
| public override void OnEntry(MethodExecutionArgs args) | |||
| { | |||
| func_name = $"autograph_{args.Instance.GetType().FullName}.{args.Method.Name}"; | |||
| func_name = $"autograph_{args.Instance.GetHashCode()}.{args.Method.Name}"; | |||
| if (functions.ContainsKey(func_name)) | |||
| { | |||
| @@ -44,13 +44,13 @@ namespace Tensorflow.Graphs | |||
| } | |||
| else | |||
| { | |||
| originalInputs = new Tensors(args.Arguments.Length); | |||
| originalInputs = new Tensors(); | |||
| // convert args to placeholder | |||
| for (var i = 0; i < args.Arguments.Length; i++) | |||
| { | |||
| if (args.Arguments[i] is EagerTensor tensor) | |||
| { | |||
| originalInputs[i] = tensor; | |||
| originalInputs.Add(tensor); | |||
| args.Arguments[i] = tf.placeholder(tensor.dtype, shape: tensor.TensorShape, name: "inputs"); | |||
| } | |||
| } | |||
| @@ -16,16 +16,23 @@ namespace Tensorflow.Graphs | |||
| Graph outer_graph; | |||
| public Graph OuterGraph => outer_graph; | |||
| string func_name; | |||
| // _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||
| IntPtr func_handle; | |||
| public string FuncName => func_name; | |||
| public string FuncName => _graph_key; | |||
| public Tensors Inputs { get; set; } | |||
| public Tensors Outputs { get; set; } | |||
| public Dictionary<string, string> Attrs { get; set; } | |||
| public Dictionary<long, (Tensor, Tensor)> _captures | |||
| = new Dictionary<long, (Tensor, Tensor)>(); | |||
| public Tensor[] external_captures() | |||
| => _captures.Select(x => x.Value.Item1).ToArray(); | |||
| public Tensor[] internal_captures() | |||
| => _captures.Select(x => x.Value.Item2).ToArray(); | |||
| // new Dictionary<long, (Tensor, Tensor)> _captures = new Dictionary<long, (Tensor, Tensor)>(); | |||
| // public new Tensor[] external_captures => _captures.Values.Select(x => x.Item1).ToArray(); | |||
| @@ -35,7 +42,7 @@ namespace Tensorflow.Graphs | |||
| public FuncGraph(string name) : base() | |||
| { | |||
| outer_graph = ops.get_default_graph(); | |||
| func_name = name; | |||
| _graph_key = name; | |||
| tf.Context.graph_mode(); | |||
| as_default(); | |||
| @@ -44,7 +51,7 @@ namespace Tensorflow.Graphs | |||
| public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() | |||
| { | |||
| outer_graph = ops.get_default_graph(); | |||
| func_name = name; | |||
| _graph_key = name; | |||
| Attrs = attrs; | |||
| // Will to test if FuncGraph has memory leak | |||
| // c_api.TF_DeleteGraph(_handle); | |||
| @@ -60,7 +67,7 @@ namespace Tensorflow.Graphs | |||
| { | |||
| using var status = new Status(); | |||
| func_handle = c_api.TF_GraphToFunction(_handle, | |||
| func_name, | |||
| _graph_key, | |||
| false, | |||
| opers.Length, | |||
| opers.Select(x => (IntPtr)x).ToArray(), | |||
| @@ -82,7 +89,7 @@ namespace Tensorflow.Graphs | |||
| c_api.TFE_ContextAddFunction(tf.Context.Handle, func_handle, status.Handle); | |||
| status.Check(true); | |||
| func_name = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); | |||
| _graph_key = c_api.StringPiece(c_api.TF_FunctionName(func_handle)); | |||
| Inputs = inputs; | |||
| // mark_as_return | |||
| @@ -131,7 +138,7 @@ namespace Tensorflow.Graphs | |||
| Tensor _capture_helper(Tensor tensor, string name, TensorShape shape = null) | |||
| { | |||
| Tensor placeholder = null; | |||
| if (!_captures.Contains(tensor.Id)) | |||
| if (!_captures.ContainsKey(tensor.Id)) | |||
| { | |||
| placeholder = _create_substitute_placeholder(tensor, | |||
| name: name, | |||
| @@ -141,7 +148,7 @@ namespace Tensorflow.Graphs | |||
| } | |||
| else | |||
| { | |||
| placeholder = (((Tensor, Tensor))_captures[tensor.Id]).Item2; | |||
| placeholder = _captures[tensor.Id].Item2; | |||
| } | |||
| BackwardFunction _backward_function_wrapper = (output_grads, unneeded_gradients) => | |||
| @@ -160,7 +167,7 @@ namespace Tensorflow.Graphs | |||
| void add_capture(Tensor tensor, Tensor placeholder) | |||
| { | |||
| _captures[tensor.Id] = (tensor, placeholder); | |||
| _captures.Add(tensor.Id, (tensor, placeholder)); | |||
| if (Inputs == null) | |||
| Inputs = new Tensors(placeholder); | |||
| else | |||
| @@ -87,7 +87,7 @@ namespace Tensorflow | |||
| private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | |||
| public string _name_stack = ""; | |||
| private string _graph_key; | |||
| protected string _graph_key; | |||
| public string graph_key => _graph_key; | |||
| public string _last_loss_reduction; | |||
| public bool _is_loss_scaled_by_optimizer { get; set; } | |||
| @@ -552,23 +552,5 @@ namespace Tensorflow | |||
| { | |||
| return graph._handle; | |||
| } | |||
| public OrderedDictionary _captures => new OrderedDictionary(); | |||
| public Tensor[] external_captures() | |||
| { | |||
| Tensor[] captures = new Tensor[_captures.Count]; | |||
| ICollection inner = _captures.Keys; // c[0] | |||
| inner.CopyTo(captures, 0); | |||
| return captures; | |||
| } | |||
| public Tensor[] internal_captures() | |||
| { | |||
| Tensor[] captures = new Tensor[_captures.Count]; | |||
| ICollection inner = _captures.Values; // c[1] | |||
| inner.CopyTo(captures, 0); | |||
| return captures; | |||
| } | |||
| } | |||
| } | |||
| @@ -15,7 +15,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public class Tensors : IEnumerable<Tensor> | |||
| { | |||
| Tensor[] items; | |||
| List<Tensor> items = new List<Tensor>(); | |||
| public TF_DataType dtype => items.First().dtype; | |||
| public TensorShape shape => items.First().TensorShape; | |||
| @@ -23,7 +23,7 @@ namespace Tensorflow | |||
| public Graph graph => items.First().graph; | |||
| public bool IsEagerTensor => items.First().IsEagerTensor; | |||
| public bool IsList { get; set; } | |||
| public int Length => items.Length; | |||
| public int Length => items.Count(); | |||
| public Tensor this[int index] | |||
| { | |||
| @@ -40,17 +40,12 @@ namespace Tensorflow | |||
| public Tensors(params Tensor[] tensors) | |||
| { | |||
| items = tensors; | |||
| items.AddRange(tensors); | |||
| } | |||
| public Tensors(NDArray nd) | |||
| { | |||
| items = new[] { ops.convert_to_tensor(nd) }; | |||
| } | |||
| public Tensors(int count) | |||
| { | |||
| items = new Tensor[count]; | |||
| items.Add(ops.convert_to_tensor(nd)); | |||
| } | |||
| public IEnumerator<Tensor> GetEnumerator() | |||
| @@ -59,6 +54,9 @@ namespace Tensorflow | |||
| yield return tensor; | |||
| } | |||
| public void Add(Tensor tensor) | |||
| => items.Add(tensor); | |||
| IEnumerator IEnumerable.GetEnumerator() | |||
| { | |||
| throw new NotImplementedException(); | |||
| @@ -80,11 +78,11 @@ namespace Tensorflow | |||
| => tensors.FirstOrDefault(); | |||
| public static implicit operator Tensor[](Tensors tensors) | |||
| => tensors.items; | |||
| => tensors.items.ToArray(); | |||
| public override string ToString() | |||
| => items.Length == 1 | |||
| => items.Count() == 1 | |||
| ? items.First().ToString() | |||
| : items.Length + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); | |||
| : items.Count() + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name)); | |||
| } | |||
| } | |||
| @@ -20,6 +20,7 @@ using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Graphs; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -392,14 +393,14 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||
| } | |||
| else if (tensor.op.type == "Placeholder" && | |||
| tensor.op.graph.building_function && | |||
| hasattr(tensor.op.graph, "internal_captures")) | |||
| tensor.op.graph is FuncGraph func_graph) | |||
| { | |||
| int i = 0; | |||
| foreach (Tensor capture in tensor.op.graph.internal_captures()) | |||
| foreach (Tensor capture in func_graph.internal_captures()) | |||
| { | |||
| if (capture.GetType() == typeof(Tensor)) | |||
| { | |||
| var external_capture = tensor.op.graph.external_captures()[i]; | |||
| var external_capture = func_graph.external_captures()[i]; | |||
| return constant_value_as_shape(external_capture); | |||
| } | |||
| @@ -337,10 +337,10 @@ namespace Tensorflow.Keras.Engine | |||
| var layer_inputs = node.MapArguments(tensor_dict); | |||
| tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name}"); | |||
| tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}"); | |||
| var outputs = node.Layer.Apply(layer_inputs, is_training: training); | |||
| foreach (var output in outputs.Where(x => x != null)) | |||
| tf.Logger.Debug($"{depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); | |||
| tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}"); | |||
| // Update tensor_dict for next input | |||
| foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | |||
| tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y)); | |||
| @@ -60,7 +60,7 @@ namespace Tensorflow.Keras.Engine | |||
| tf.Runner.RecordGradient(op_type, op.inputs._inputs, null, op.outputs); | |||
| return op.output; | |||
| return op.outputs; | |||
| } | |||
| } | |||
| } | |||