From e75a1116205b9123897f8ba312505345495cca8c Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 3 Jan 2021 20:45:08 -0600 Subject: [PATCH] lift_to_graph --- src/TensorFlowNET.Core/Binding.Util.cs | 18 +- .../Functions/ConcreteFunction.cs | 4 +- .../Functions/TapeGradientFunctions.cs | 4 +- src/TensorFlowNET.Core/Graphs/FuncGraph.cs | 11 +- .../Graphs/SubGraphUtility.cs | 175 ++++++++++++++++++ .../Operations/gen_math_ops.cs | 16 -- src/TensorFlowNET.Core/Operations/math_ops.cs | 34 ++++ .../Tensors/Tensor.Operators.cs | 11 +- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 24 ++- src/TensorFlowNET.Keras/BackendImpl.cs | 42 ++++- 10 files changed, 290 insertions(+), 49 deletions(-) create mode 100644 src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 33fdad7c..62ba0bbd 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -46,6 +46,15 @@ namespace Tensorflow } } + public static void difference_update(this IList list, IList list2) + { + foreach(var el in list2) + { + if (list.Contains(el)) + list.Remove(el); + } + } + public static void add(this IList list, T element) => list.Add(element); @@ -158,6 +167,13 @@ namespace Tensorflow return Enumerable.Range(start, end - start); } + public static IEnumerable reversed(IList values) + { + var len = values.Count; + for (int i = len - 1; i >= 0; i--) + yield return values[i]; + } + public static T New() where T : ITensorFlowObject, new() { var instance = new T(); @@ -284,7 +300,7 @@ namespace Tensorflow for (int i = 0; i < len; i++) yield return (i, values[i]); } - + public static IEnumerable<(int, T)> enumerate(IEnumerable values, int start = 0, int step = 1) { int i = 0; diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index e9203878..45fa3420 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -14,7 +14,7 @@ namespace Tensorflow.Functions { IntPtr _handle; FuncGraph func_graph; - public Tensor[] CapturedInputs => func_graph.external_captures(); + public Tensor[] CapturedInputs => func_graph.external_captures; public string Name { @@ -37,7 +37,7 @@ namespace Tensorflow.Functions func_graph.as_default(); } - public ConcreteFunction(FuncGraph graph, Dictionary attrs) + public ConcreteFunction(FuncGraph graph, Dictionary attrs = null) { func_graph = graph; diff --git a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs index 4559fc5d..5dd1a8ae 100644 --- a/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs +++ b/src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs @@ -93,7 +93,7 @@ namespace Tensorflow.Functions grad_ys: gradients_wrt_outputs.ToArray(), src_graph: _func_graph); - var captures_from_forward = backwards_graph.external_captures() + var captures_from_forward = backwards_graph.external_captures .Where(x => !x.IsEagerTensor && x.graph == _func_graph) .ToArray(); foreach(var capture in captures_from_forward) @@ -105,7 +105,7 @@ namespace Tensorflow.Functions var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; var backward_function_attr = new Dictionary(); backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; - gradients_wrt_outputs.append(backwards_graph.internal_captures()); + gradients_wrt_outputs.append(backwards_graph.internal_captures); backwards_graph.Inputs = gradients_wrt_outputs; backwards_graph.Outputs = gradients_wrt_inputs; diff --git a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs index bc2eebb4..37f03267 100644 --- a/src/TensorFlowNET.Core/Graphs/FuncGraph.cs +++ b/src/TensorFlowNET.Core/Graphs/FuncGraph.cs @@ -21,15 +21,20 @@ namespace Tensorflow.Graphs public Tensors Outputs { get; set; } = new Tensors(); public Dictionary Attrs { get; set; } - public Dictionary _captures + Dictionary _captures = new Dictionary(); - public Tensor[] external_captures() + public Tensor[] external_captures => _captures.Select(x => x.Value.Item1).ToArray(); + public (Tensor, Tensor)[] captures + => _captures.Values.Select(x => x).ToArray(); - public Tensor[] internal_captures() + public Tensor[] internal_captures => _captures.Select(x => x.Value.Item2).ToArray(); + public Tensor[] captured_inputs + => external_captures; + /// /// Construct a new FuncGraph. /// diff --git a/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs b/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs new file mode 100644 index 00000000..7bc7abe4 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs @@ -0,0 +1,175 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Linq; +using static Tensorflow.Binding; + +namespace Tensorflow.Graphs +{ + public class SubGraphUtility + { + /// + /// Copies the tensor and all its inputs recursively to the outer graph. + /// + /// + /// + /// + /// + /// + /// + public static Dictionary lift_to_graph(Tensors init_tensors, + FuncGraph graph, + List sources, + bool add_sources = false, + bool handle_captures = false, + Graph base_graph = null, + Dictionary op_map = null) + { + base_graph = base_graph ?? init_tensors[0].graph; + op_map = op_map ?? new Dictionary(); + var visited_ops = sources.Select(x => x.op).ToList(); + foreach (var init_tensor in init_tensors) + { + var src = map_subgraph(init_tensor, sources, visited_ops, add_sources); + sources.AddRange(src); + } + + var ops_to_copy = new List(); + var marked_ops = new List(); + var ops_to_visit = new Stack(init_tensors.Select(x => x.op)); + var unvisited_ops = new List(ops_to_visit.ToList()); + while (unvisited_ops.Count > 0) + { + while(ops_to_visit.Count > 0) + { + var op = ops_to_visit.Pop(); + if (marked_ops.Contains(op)) + continue; + marked_ops.Add(op); + ops_to_copy.append(op); + foreach(var inp in op.inputs) + { + + } + } + // difference_update + unvisited_ops.difference_update(marked_ops); + if (unvisited_ops.Count > 0) + ops_to_visit.Push(unvisited_ops.Last()); + } + + // When lifting from one FuncGraph to another, we will need to capture the + // relevant tensors as well. + var inverse_captures = new Dictionary(); + Tensor[] internal_captures = null; + if (base_graph is FuncGraph base_func_graph) + { + var captures = base_func_graph.captures; + foreach (var (external_capture, internal_capture) in captures) + inverse_captures[internal_capture] = external_capture; + internal_captures = base_func_graph.internal_captures; + } + + graph.as_default(); + var source_ops = new List(); + // Add the sources in the same order as the original graph. + foreach (var s in internal_captures) + { + if (sources.Contains(s)) + { + sources.Remove(s); + source_ops.Add(s.op); + _copy_source(s: s, + graph: graph, + op_map: op_map, + handle_captures: handle_captures, + inverse_captures: inverse_captures, + base_graph: base_graph); + } + } + + foreach(var op in reversed(ops_to_copy)) + { + if (source_ops.Contains(op) || op_map.ContainsKey(op)) + continue; + _copy_non_source(op, graph, op_map, base_graph); + } + + return op_map; + } + + static void _copy_source(Tensor s, + FuncGraph graph, + Dictionary op_map, + bool handle_captures, + Dictionary inverse_captures, + Graph base_graph) + { + Tensor copied_placeholder = null; + if (handle_captures && inverse_captures.ContainsKey(s)) + copied_placeholder = graph.capture(inverse_captures[s], name: s.op.name); + else + throw new NotImplementedException(""); + op_map[s] = copied_placeholder; + // Add an entry for the op of the source tensor so that if there are any nodes + // depending on that op via control dependencies it can work correctly. + op_map[s.op] = copied_placeholder.op; + } + + static void _copy_non_source(Operation op, FuncGraph graph, Dictionary op_map, Graph base_graph) + { + Operation copied_op = null; + var copied_inputs = new Tensors(); + tf_with(ops.control_dependencies(new object[] { op }), delegate + { + // Create a new op in the destination graph if it doesn't exist before. + var attrs = new Dictionary(); + foreach (var attr_def in op.node_def.Attr) + attrs[attr_def.Key] = attr_def.Value; + + copied_op = graph.create_op(op.type, + copied_inputs, + dtypes: op.outputs.Select(x => x.dtype).ToArray(), + attrs: attrs, + name: op.name); + }); + op_map[op] = copied_op; + foreach (var (i, o) in enumerate(op.outputs)) + op_map[o] = copied_op.outputs[i]; + } + + /// + /// Walk a Graph and capture the subgraph between init_tensor and sources. + /// + /// + /// + public static List map_subgraph(Tensor init_tensor, + List sources, + List visited_ops, + bool add_sources) + { + var ops_to_visit = new Stack(); + ops_to_visit.Push(init_tensor.op); + var extra_sources = new List(); + while (ops_to_visit.Count > 0) + { + var op = ops_to_visit.Pop(); + if (visited_ops.Contains(op)) + continue; + visited_ops.Add(op); + bool should_raise = false; + if (should_raise) + throw new RuntimeError($"Unable to lift tensor {init_tensor.name}."); + if(op.type == "Placeholder") + { + extra_sources.AddRange(op.outputs); + } + foreach(var inp in op.inputs) + { + + } + } + return extra_sources; + } + } +} diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs index b40dc2ae..3d64e8b9 100644 --- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs @@ -873,22 +873,6 @@ namespace Tensorflow return _op.output; } - public static Tensor mul(Tensor x, Tensor y, string name = null) - { - if (tf.Context.executing_eagerly()) - { - var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, - "Mul", name, - null, - x, y); - return results[0]; - } - - var _op = tf.OpDefLib._apply_op_helper("Mul", name, args: new { x, y }); - - return _op.output; - } - public static Tensor mul(Tx x, Ty y, string name = null) { if (tf.Context.executing_eagerly()) diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 8db47f1a..2c051992 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -44,6 +44,23 @@ namespace Tensorflow public static Tensor add(Tx x, Ty y, string name = null) => gen_math_ops.add(x, y, name); + public static Tensor add_v2(Tensor x, Tensor y, string name = null) + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("AddV2", name, new { x, y }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "AddV2", name, + null, + x, y).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T") + }; + tf.Runner.RecordGradient("AddV2", op.inputs, attrs, op.outputs); + }, + new Tensors(x, y)); + public static Tensor add_v2(Tx x, Ty y, string name = null) => gen_math_ops.add_v2(x, y, name); @@ -251,6 +268,23 @@ namespace Tensorflow public static Tensor sqrt(Tensor x, string name = null) => gen_math_ops.sqrt(x, name: name); + public static Tensor multiply(Tensor x, Tensor y, string name = null) + => tf.Context.RunInAutoMode2( + () => tf.OpDefLib._apply_op_helper("Mul", name, new { x, y }).output, + () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName, + "Mul", name, + null, + x, y).FirstOrDefault(), + (op) => + { + var attrs = new object[] + { + "T", op.get_attr("T") + }; + tf.Runner.RecordGradient("Mul", op.inputs, attrs, op.outputs); + }, + new Tensors(x, y)); + public static Tensor multiply(Tx x, Ty y, string name = null) => gen_math_ops.mul(x, y, name: name); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index 1c394238..95f571c5 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -309,25 +309,19 @@ namespace Tensorflow private static Tensor BinaryOpWrapper(string name, Tx x, Ty y) { TF_DataType dtype = TF_DataType.DtInvalid; - bool switchToGraphModeTemp = !tf.executing_eagerly(); if (x is Tensor tl) { dtype = tl.dtype.as_base_dtype(); - switchToGraphModeTemp = switchToGraphModeTemp || !tl.IsEagerTensor; } if (y is Tensor tr) { dtype = tr.dtype.as_base_dtype(); - switchToGraphModeTemp = switchToGraphModeTemp || !tr.IsEagerTensor; } return tf_with(ops.name_scope(null, name, new { x, y }), scope => { - if (switchToGraphModeTemp) - tf.Context.graph_mode(); - Tensor result; var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); @@ -347,7 +341,7 @@ namespace Tensorflow result = math_ops.truediv(x1, y1, name: scope); break; case "mul": - result = gen_math_ops.mul(x1, y1, name: scope); + result = math_ops.multiply(x1, y1, name: scope); break; case "sub": result = gen_math_ops.sub(x1, y1, name: scope); @@ -359,9 +353,6 @@ namespace Tensorflow throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}"); } - if (switchToGraphModeTemp) - tf.Context.restore_mode(); - return result; }); } diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index 449a978b..c5e964f6 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -69,27 +69,25 @@ namespace Tensorflow int num_elements = np.prod(shape); var tensor_dtype = tensor.Dtype.as_numpy_dtype(); - if (tensor.TensorContent.Length > 0) + if (shape.Length > 0 && tensor.TensorContent.Length > 0) { return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype).reshape(shape); } else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) -#pragma warning disable CS0642 // Possible mistaken empty statement - ; -#pragma warning restore CS0642 // Possible mistaken empty statement + { + return np.array(tensor.HalfVal).reshape(shape); + } else if (tensor.Dtype == DataType.DtFloat) -#pragma warning disable CS0642 // Possible mistaken empty statement - ; -#pragma warning restore CS0642 // Possible mistaken empty statement + { + return np.array(tensor.FloatVal).reshape(shape); + } else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype)) { - if (tensor.IntVal.Count == 1) - return np.repeat(np.array(tensor.IntVal[0]), num_elements).reshape(shape); + return np.array(tensor.IntVal).reshape(shape); } else if (tensor.Dtype == DataType.DtBool) { - if (tensor.BoolVal.Count == 1) - return np.repeat(np.array(tensor.BoolVal[0]), num_elements).reshape(shape); + return np.array(tensor.BoolVal).reshape(shape); } throw new NotImplementedException("MakeNdarray"); @@ -396,11 +394,11 @@ would not be rank 1.", tensor.op.get_attr("axis"))); tensor.op.graph is FuncGraph func_graph) { int i = 0; - foreach (Tensor capture in func_graph.internal_captures()) + foreach (Tensor capture in func_graph.internal_captures) { if (capture.GetType() == typeof(Tensor)) { - var external_capture = func_graph.external_captures()[i]; + var external_capture = func_graph.external_captures[i]; return constant_value_as_shape(external_capture); } diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs index 20bc99f6..e9fdf18f 100644 --- a/src/TensorFlowNET.Keras/BackendImpl.cs +++ b/src/TensorFlowNET.Keras/BackendImpl.cs @@ -17,8 +17,10 @@ using NumSharp; using System; using System.Collections.Generic; +using Tensorflow.Functions; using Tensorflow.Graphs; using static Tensorflow.Binding; +using static Tensorflow.Graphs.SubGraphUtility; namespace Tensorflow.Keras { @@ -33,6 +35,7 @@ namespace Tensorflow.Keras public Session _SESSION => ops.get_default_session(); public Graph _GRAPH; + FuncGraph _CURRENT_SCRATCH_GRAPH; public Dictionary _GRAPH_LEARNING_PHASES; //Dictionary> PER_GRAPH_LAYER_NAME_UIDS; public bool _MANUAL_VAR_INIT = false; @@ -89,6 +92,14 @@ namespace Tensorflow.Keras return ops.get_default_graph(); } + FuncGraph _scratch_graph() + { + if (_CURRENT_SCRATCH_GRAPH == null) + _CURRENT_SCRATCH_GRAPH = new FuncGraph("keras_scratch_graph"); + + return _CURRENT_SCRATCH_GRAPH; + } + public int get_uid(string prefix) { var graph = tf.get_default_graph(); @@ -168,9 +179,36 @@ namespace Tensorflow.Keras /// /// /// - public NDArray eval_in_eager_or_function(Tensor outputs) + public NDArray eval_in_eager_or_function(Tensors outputs) { - return outputs.eval(); + if (outputs[0].op.type == "Const") + return tensor_util.constant_value(outputs); + + var source_graph = outputs.graph; + using var exec_graph = _scratch_graph(); + var global_graph = get_graph(); + if (source_graph == global_graph && exec_graph != global_graph) + { + var lifted_map = lift_to_graph(outputs, exec_graph, + new List(), + add_sources: true, + handle_captures: true, + base_graph: source_graph); + } + if (outputs[0].op.type == "Placeholder" + || outputs[0].op.type == "StridedSlice") + return exec_graph.external_captures[0].numpy(); + + // Consolidate updates + exec_graph.as_default(); + exec_graph.Inputs = exec_graph.internal_captures; + exec_graph.Outputs = outputs; + + var graph_fn = new ConcreteFunction(exec_graph); + + _CURRENT_SCRATCH_GRAPH = null; + // return outputs.eval(); + throw new NotImplementedException(""); } public class _DummyEagerGraph