using System;
using System.Collections.Generic;
using System.Text;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs
{
    public class SubGraphUtility
    {
        /// <summary>
        /// Copies the tensor and all its inputs recursively to the outer graph.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the loops over <c>op.inputs</c> below are still empty, so at present
        /// only the ops that directly produce <paramref name="init_tensors"/> are collected for
        /// copying; their inputs are not walked.
        /// </remarks>
        /// <param name="init_tensors">Tensors whose producing ops are lifted.</param>
        /// <param name="graph">Destination graph; it is made the default graph while ops are copied.</param>
        /// <param name="sources">
        /// Source tensors. This list is mutated: tensors returned by <see cref="map_subgraph"/> are
        /// appended, and every source that is an internal capture of the base graph is removed once
        /// it has been copied.
        /// </param>
        /// <param name="add_sources">Forwarded to <see cref="map_subgraph"/>.</param>
        /// <param name="handle_captures">
        /// When true, a source that is an internal capture of the base graph is re-captured in
        /// <paramref name="graph"/>; any other source copy throws <see cref="NotImplementedException"/>.
        /// </param>
        /// <param name="base_graph">Graph to copy from; defaults to the graph of the first init tensor.</param>
        /// <param name="op_map">Existing mapping to extend; a new one is created when null.</param>
        /// <returns>The mapping from original tensors/operations to their copies.</returns>
        public static Dictionary<ITensorOrOperation, ITensorOrOperation> lift_to_graph(Tensors init_tensors,
            FuncGraph graph,
            List<Tensor> sources,
            bool add_sources = false,
            bool handle_captures = false,
            Graph base_graph = null,
            Dictionary<ITensorOrOperation, ITensorOrOperation> op_map = null)
        {
            base_graph = base_graph ?? init_tensors[0].graph;
            op_map = op_map ?? new Dictionary<ITensorOrOperation, ITensorOrOperation>();

            var visited_ops = sources.Select(x => x.op).ToList();
            foreach (var init_tensor in init_tensors)
            {
                var src = map_subgraph(init_tensor, sources, visited_ops, add_sources);
                sources.AddRange(src);
            }

            var ops_to_copy = new List<Operation>();
            var marked_ops = new List<Operation>();
            var ops_to_visit = new Stack<Operation>(init_tensors.Select(x => x.op));
            var unvisited_ops = new List<Operation>(ops_to_visit.ToList());
            while (unvisited_ops.Count > 0)
            {
                while (ops_to_visit.Count > 0)
                {
                    var op = ops_to_visit.Pop();
                    if (marked_ops.Contains(op))
                    {
                        continue;
                    }

                    marked_ops.Add(op);
                    ops_to_copy.append(op);
                    // TODO: inputs are enumerated but not yet pushed onto ops_to_visit.
                    foreach (var inp in op.inputs)
                    {
                    }
                }

                // difference_update
                unvisited_ops.difference_update(marked_ops);
                if (unvisited_ops.Count > 0)
                {
                    ops_to_visit.Push(unvisited_ops.Last());
                }
            }

            // When lifting from one FuncGraph to another, we will need to capture the
            // relevant tensors as well.
            var inverse_captures = new Dictionary<Tensor, Tensor>();
            // Stays empty when base_graph is not a FuncGraph, so the loop over it below is a
            // no-op instead of a NullReferenceException.
            Tensor[] internal_captures = Array.Empty<Tensor>();
            if (base_graph is FuncGraph base_func_graph)
            {
                foreach (var (external_capture, internal_capture) in base_func_graph.captures)
                {
                    inverse_captures[internal_capture] = external_capture;
                }

                internal_captures = base_func_graph.internal_captures ?? internal_captures;
            }

            graph.as_default();
            try
            {
                var source_ops = new List<Operation>();
                // Add the sources in the same order as the original graph.
                foreach (var s in internal_captures)
                {
                    // Remove reports whether s was one of the sources, so a separate
                    // Contains scan is not needed.
                    if (!sources.Remove(s))
                    {
                        continue;
                    }

                    source_ops.Add(s.op);
                    _copy_source(s: s,
                        graph: graph,
                        op_map: op_map,
                        handle_captures: handle_captures,
                        inverse_captures: inverse_captures,
                        base_graph: base_graph);
                }

                foreach (var op in reversed(ops_to_copy))
                {
                    if (source_ops.Contains(op) || op_map.ContainsKey(op))
                    {
                        continue;
                    }

                    _copy_non_source(op, graph, op_map, base_graph);
                }
            }
            finally
            {
                // Always leave the destination graph's scope, even when a copy step throws.
                graph.Exit();
            }

            return op_map;
        }

        /// <summary>
        /// Copies a source tensor into <paramref name="graph"/> by capturing its external
        /// counterpart, and records the tensor and its op in <paramref name="op_map"/>.
        /// </summary>
        /// <param name="s">Source tensor to copy.</param>
        /// <param name="graph">Destination graph that captures the tensor.</param>
        /// <param name="op_map">Mapping that receives entries for <paramref name="s"/> and its op.</param>
        /// <param name="handle_captures">Must be true for the capture path to be taken.</param>
        /// <param name="inverse_captures">Maps an internal capture to its external tensor.</param>
        /// <param name="base_graph">Not used by the current implementation.</param>
        /// <exception cref="NotImplementedException">
        /// Thrown when <paramref name="handle_captures"/> is false or <paramref name="s"/> has no
        /// entry in <paramref name="inverse_captures"/>.
        /// </exception>
        static void _copy_source(Tensor s,
            FuncGraph graph,
            Dictionary<ITensorOrOperation, ITensorOrOperation> op_map,
            bool handle_captures,
            Dictionary<Tensor, Tensor> inverse_captures,
            Graph base_graph)
        {
            Tensor copied_placeholder;
            if (handle_captures && inverse_captures.TryGetValue(s, out var external_capture))
            {
                copied_placeholder = graph.capture(external_capture, name: s.op.name);
            }
            else
            {
                throw new NotImplementedException(
                    "Copying a source tensor is only implemented for captured tensors with handle_captures enabled.");
            }

            op_map[s] = copied_placeholder;
            // Add an entry for the op of the source tensor so that if there are any nodes
            // depending on that op via control dependencies it can work correctly.
            op_map[s.op] = copied_placeholder.op;
        }

        /// <summary>
        /// Creates a copy of <paramref name="op"/> in <paramref name="graph"/> and records the op
        /// and each of its outputs in <paramref name="op_map"/>.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the copy is created with an empty input list; the original op's inputs
        /// are not mapped across.
        /// </remarks>
        /// <param name="op">Operation to copy.</param>
        /// <param name="graph">Destination graph in which the new op is created.</param>
        /// <param name="op_map">Mapping that receives entries for the op and its outputs.</param>
        /// <param name="base_graph">Not used by the current implementation.</param>
        static void _copy_non_source(Operation op,
            FuncGraph graph,
            Dictionary<ITensorOrOperation, ITensorOrOperation> op_map,
            Graph base_graph)
        {
            Operation copied_op = null;
            var copied_inputs = new Tensors();
            tf_with(ops.control_dependencies(new object[] { op }), delegate
            {
                // Create a new op in the destination graph if it doesn't exist before.
                var attrs = op.node_def.Attr.ToDictionary(attr_def => attr_def.Key, attr_def => attr_def.Value);

                copied_op = graph.create_op(op.type,
                    copied_inputs,
                    dtypes: op.outputs.Select(x => x.dtype).ToArray(),
                    attrs: attrs,
                    name: op.name);
            });

            op_map[op] = copied_op;
            foreach (var (i, o) in enumerate(op.outputs))
            {
                op_map[o] = copied_op.outputs[i];
            }
        }

        /// <summary>
        /// Walk a Graph and capture the subgraph between init_tensor and sources.
        /// </summary>
        /// <remarks>
        /// NOTE(review): the loop over <c>op.inputs</c> is still empty, so at present only the op
        /// of <paramref name="init_tensor"/> is ever visited.
        /// </remarks>
        /// <param name="init_tensor">Tensor whose producing op starts the walk.</param>
        /// <param name="sources">Not used by the current implementation.</param>
        /// <param name="visited_ops">Ops to skip; every op visited here is appended to it.</param>
        /// <param name="add_sources">Not used by the current implementation.</param>
        /// <returns>The outputs of every visited op whose type is "Placeholder".</returns>
        public static List<Tensor> map_subgraph(Tensor init_tensor,
            List<Tensor> sources,
            List<Operation> visited_ops,
            bool add_sources)
        {
            var ops_to_visit = new Stack<Operation>();
            ops_to_visit.Push(init_tensor.op);
            var extra_sources = new List<Tensor>();
            while (ops_to_visit.Count > 0)
            {
                var op = ops_to_visit.Pop();
                if (visited_ops.Contains(op))
                {
                    continue;
                }

                visited_ops.Add(op);

                // TODO: the condition that sets should_raise has not been ported yet,
                // so this branch never fires.
                bool should_raise = false;
                if (should_raise)
                {
                    throw new RuntimeError($"Unable to lift tensor {init_tensor.name}.");
                }

                if (op.type == "Placeholder")
                {
                    extra_sources.AddRange(op.outputs);
                }

                // TODO: inputs are enumerated but not yet pushed onto ops_to_visit.
                foreach (var inp in op.inputs)
                {
                }
            }

            return extra_sources;
        }
    }
}