using System;
using System.Collections.Generic;
using System.Text;
using System.Linq;
using static Tensorflow.Binding;
namespace Tensorflow.Graphs
{
public class SubGraphUtility
{
    /// <summary>
    /// Copies the ops that produce <paramref name="init_tensors"/> from
    /// <paramref name="base_graph"/> into <paramref name="graph"/> and records
    /// the original-to-copy mapping in <paramref name="op_map"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the loops over <c>op.inputs</c> in this method and in
    /// <c>map_subgraph</c> have empty bodies, so as written the traversal never
    /// walks from an op to the producers of its inputs; only the ops of
    /// <paramref name="init_tensors"/> themselves are collected for copying.
    /// </remarks>
    /// <param name="init_tensors">Tensors whose producing ops seed the copy. Must contain at least one tensor when <paramref name="base_graph"/> is null, because the first tensor's graph is used as the default.</param>
    /// <param name="graph">Destination graph. It is entered as the default graph while copying and exited afterwards, even if copying throws.</param>
    /// <param name="sources">Tensors treated as sources. This list is mutated: the outputs of every "Placeholder" op found by <c>map_subgraph</c> are appended, and each source that is copied as a capture is removed.</param>
    /// <param name="add_sources">Forwarded to <c>map_subgraph</c>.</param>
    /// <param name="handle_captures">Forwarded to <c>_copy_source</c>; source tensors can only be copied when this is true.</param>
    /// <param name="base_graph">Graph to copy from. Defaults to the graph of the first tensor in <paramref name="init_tensors"/>.</param>
    /// <param name="op_map">Existing mapping to extend. A new dictionary is created when null.</param>
    /// <returns>The mapping from original tensors/ops to their copies (the same instance as <paramref name="op_map"/> when one was supplied).</returns>
    /// <exception cref="NotImplementedException">Thrown by <c>_copy_source</c> when a source has to be copied but cannot be captured.</exception>
    public static Dictionary lift_to_graph(Tensors init_tensors,
        FuncGraph graph,
        List sources,
        bool add_sources = false,
        bool handle_captures = false,
        Graph base_graph = null,
        Dictionary op_map = null)
    {
        base_graph = base_graph ?? init_tensors[0].graph;
        op_map = op_map ?? new Dictionary();

        // The ops producing the given sources count as already visited, so
        // map_subgraph will not report them again.
        var visited_ops = sources.Select(x => x.op).ToList();
        foreach (var init_tensor in init_tensors)
        {
            var src = map_subgraph(init_tensor, sources, visited_ops, add_sources);
            sources.AddRange(src);
        }

        var ops_to_copy = new List();
        var marked_ops = new List();
        var ops_to_visit = new Stack(init_tensors.Select(x => x.op));
        var unvisited_ops = new List(ops_to_visit.ToList());
        while (unvisited_ops.Count > 0)
        {
            while (ops_to_visit.Count > 0)
            {
                var op = ops_to_visit.Pop();
                if (marked_ops.Contains(op))
                {
                    continue;
                }

                marked_ops.Add(op);
                ops_to_copy.append(op);

                // NOTE(review): empty body - the producers of op's inputs are
                // never pushed onto ops_to_visit. Left untouched on purpose;
                // filling it in would change which ops get copied.
                foreach (var inp in op.inputs)
                {
                }
            }

            // Drop everything that has been marked; if anything is left,
            // restart the inner walk from the last remaining op.
            unvisited_ops.difference_update(marked_ops);
            if (unvisited_ops.Count > 0)
            {
                ops_to_visit.Push(unvisited_ops.Last());
            }
        }

        // When lifting from one FuncGraph to another, we will need to capture the
        // relevant tensors as well.
        var inverse_captures = new Dictionary();

        // Start from an empty array so the loop below is a no-op (instead of a
        // NullReferenceException) when base_graph is not a FuncGraph.
        Tensor[] internal_captures = new Tensor[0];
        if (base_graph is FuncGraph base_func_graph)
        {
            var captures = base_func_graph.captures;
            foreach (var (external_capture, internal_capture) in captures)
            {
                inverse_captures[internal_capture] = external_capture;
            }

            internal_captures = base_func_graph.internal_captures;
        }

        graph.as_default();
        try
        {
            var source_ops = new List();

            // Add the sources in the same order as the original graph.
            foreach (var s in internal_captures)
            {
                if (!sources.Contains(s))
                {
                    continue;
                }

                sources.Remove(s);
                source_ops.Add(s.op);
                _copy_source(s: s,
                    graph: graph,
                    op_map: op_map,
                    handle_captures: handle_captures,
                    inverse_captures: inverse_captures,
                    base_graph: base_graph);
            }

            // Copy in reverse collection order, skipping ops that are sources
            // or that already have an entry in op_map.
            foreach (var op in reversed(ops_to_copy))
            {
                if (source_ops.Contains(op) || op_map.ContainsKey(op))
                {
                    continue;
                }

                _copy_non_source(op, graph, op_map, base_graph);
            }
        }
        finally
        {
            // Always leave the destination graph, even when a copy step throws,
            // so as_default()/Exit() stay balanced.
            graph.Exit();
        }

        return op_map;
    }

    /// <summary>
    /// Copies a source tensor into <paramref name="graph"/> by capturing its
    /// external counterpart, and records both the tensor and its op in
    /// <paramref name="op_map"/>.
    /// </summary>
    /// <param name="s">Source tensor to copy.</param>
    /// <param name="graph">Destination graph that performs the capture.</param>
    /// <param name="op_map">Mapping that receives the entries for <paramref name="s"/> and for its op.</param>
    /// <param name="handle_captures">Must be true for the capture path to be taken.</param>
    /// <param name="inverse_captures">Lookup from internal capture to external capture.</param>
    /// <param name="base_graph">Not read by this method.</param>
    /// <exception cref="NotImplementedException">
    /// Thrown when <paramref name="handle_captures"/> is false or
    /// <paramref name="s"/> has no entry in <paramref name="inverse_captures"/>.
    /// </exception>
    static void _copy_source(Tensor s,
        FuncGraph graph,
        Dictionary op_map,
        bool handle_captures,
        Dictionary inverse_captures,
        Graph base_graph)
    {
        // Only the capture path is implemented; every other source is rejected.
        if (!handle_captures || !inverse_captures.TryGetValue(s, out var external_capture))
        {
            throw new NotImplementedException(
                "Copying a source tensor that is not a captured tensor (or with handle_captures disabled) is not implemented.");
        }

        Tensor copied_placeholder = graph.capture(external_capture, name: s.op.name);
        op_map[s] = copied_placeholder;

        // Add an entry for the op of the source tensor so that if there are any nodes
        // depending on that op via control dependencies it can work correctly.
        op_map[s.op] = copied_placeholder.op;
    }

    /// <summary>
    /// Creates a copy of <paramref name="op"/> in <paramref name="graph"/> with
    /// the same type, name, attributes and output dtypes, then records the op
    /// and each of its outputs in <paramref name="op_map"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the copy is created with an empty input collection; the
    /// original op's inputs are not looked up in <paramref name="op_map"/>.
    /// </remarks>
    /// <param name="op">Operation to copy.</param>
    /// <param name="graph">Destination graph in which the new op is created.</param>
    /// <param name="op_map">Mapping that receives the entries for the op and its outputs.</param>
    /// <param name="base_graph">Not read by this method.</param>
    static void _copy_non_source(Operation op, FuncGraph graph, Dictionary op_map, Graph base_graph)
    {
        Operation copied_op = null;
        var copied_inputs = new Tensors();
        tf_with(ops.control_dependencies(new object[] { op }), delegate
        {
            // Create a new op in the destination graph if it doesn't exist before.
            var attrs = new Dictionary();
            foreach (var attr_def in op.node_def.Attr)
            {
                attrs[attr_def.Key] = attr_def.Value;
            }

            copied_op = graph.create_op(op.type,
                copied_inputs,
                dtypes: op.outputs.Select(x => x.dtype).ToArray(),
                attrs: attrs,
                name: op.name);
        });

        op_map[op] = copied_op;

        // Pair each original output with the copy's output at the same index.
        foreach (var (i, o) in enumerate(op.outputs))
        {
            op_map[o] = copied_op.outputs[i];
        }
    }

    /// <summary>
    /// Walk a Graph and capture the subgraph between init_tensor and sources.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the loop over <c>op.inputs</c> has an empty body, so as
    /// written only the op of <paramref name="init_tensor"/> is ever visited.
    /// No check for unliftable placeholders is performed.
    /// </remarks>
    /// <param name="init_tensor">Tensor whose producing op is the starting point of the walk.</param>
    /// <param name="sources">Not read by this method.</param>
    /// <param name="visited_ops">Ops to skip. Every op visited by this call is appended to it.</param>
    /// <param name="add_sources">Not read by this method.</param>
    /// <returns>The outputs of every visited op whose type is "Placeholder".</returns>
    public static List map_subgraph(Tensor init_tensor,
        List sources,
        List visited_ops,
        bool add_sources)
    {
        var ops_to_visit = new Stack();
        ops_to_visit.Push(init_tensor.op);
        var extra_sources = new List();
        while (ops_to_visit.Count > 0)
        {
            var op = ops_to_visit.Pop();
            if (visited_ops.Contains(op))
            {
                continue;
            }

            visited_ops.Add(op);

            // Placeholders are inputs from outside the subgraph, so their
            // outputs are reported back to the caller as additional sources.
            if (op.type == "Placeholder")
            {
                extra_sources.AddRange(op.outputs);
            }

            // NOTE(review): empty body - input producers are never pushed onto
            // ops_to_visit. Left untouched on purpose; filling it in would
            // change which ops are visited.
            foreach (var inp in op.inputs)
            {
            }
        }

        return extra_sources;
    }
}
}