
lift_to_graph

tags/keras_v0.3.0
Oceania2018 4 years ago
parent commit e75a111620
10 changed files with 290 additions and 49 deletions
  1. +17 -1    src/TensorFlowNET.Core/Binding.Util.cs
  2. +2 -2     src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  3. +2 -2     src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  4. +8 -3     src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  5. +175 -0   src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs
  6. +0 -16    src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  7. +34 -0    src/TensorFlowNET.Core/Operations/math_ops.cs
  8. +1 -10    src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  9. +11 -13   src/TensorFlowNET.Core/Tensors/tensor_util.cs
  10. +40 -2   src/TensorFlowNET.Keras/BackendImpl.cs

+17 -1   src/TensorFlowNET.Core/Binding.Util.cs

@@ -46,6 +46,15 @@ namespace Tensorflow
}
}

public static void difference_update<T>(this IList<T> list, IList<T> list2)
{
foreach(var el in list2)
{
if (list.Contains(el))
list.Remove(el);
}
}

public static void add<T>(this IList<T> list, T element)
=> list.Add(element);

@@ -158,6 +167,13 @@ namespace Tensorflow
return Enumerable.Range(start, end - start);
}

public static IEnumerable<T> reversed<T>(IList<T> values)
{
var len = values.Count;
for (int i = len - 1; i >= 0; i--)
yield return values[i];
}

public static T New<T>() where T : ITensorFlowObject, new()
{
var instance = new T();
@@ -284,7 +300,7 @@ namespace Tensorflow
for (int i = 0; i < len; i++)
yield return (i, values[i]);
}
public static IEnumerable<(int, T)> enumerate<T>(IEnumerable<T> values, int start = 0, int step = 1)
{
int i = 0;


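For reference, a minimal usage sketch of the two helpers this hunk adds. The sample values are made up, and it assumes the helpers live on the static Tensorflow.Binding class like the rest of Binding.Util.cs:

using System;
using System.Collections.Generic;
using Tensorflow;                    // extension method difference_update<T>
using static Tensorflow.Binding;     // static helper reversed<T>

class BindingUtilDemo
{
    static void Main()
    {
        var ops = new List<string> { "a", "b", "c", "d" };
        var marked = new List<string> { "b", "d" };

        // Python-style set.difference_update: drop every element of `marked` from `ops`.
        ops.difference_update(marked);          // ops is now ["a", "c"]

        // Python-style reversed(): iterate the list back to front without copying it.
        foreach (var name in reversed(ops))
            Console.WriteLine(name);            // prints "c", then "a"
    }
}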
+2 -2   src/TensorFlowNET.Core/Functions/ConcreteFunction.cs

@@ -14,7 +14,7 @@ namespace Tensorflow.Functions
{
IntPtr _handle;
FuncGraph func_graph;
public Tensor[] CapturedInputs => func_graph.external_captures();
public Tensor[] CapturedInputs => func_graph.external_captures;

public string Name
{
@@ -37,7 +37,7 @@ namespace Tensorflow.Functions
func_graph.as_default();
}

public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs)
public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null)
{
func_graph = graph;



+2 -2   src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs

@@ -93,7 +93,7 @@ namespace Tensorflow.Functions
grad_ys: gradients_wrt_outputs.ToArray(),
src_graph: _func_graph);

var captures_from_forward = backwards_graph.external_captures()
var captures_from_forward = backwards_graph.external_captures
.Where(x => !x.IsEagerTensor && x.graph == _func_graph)
.ToArray();
foreach(var capture in captures_from_forward)
@@ -105,7 +105,7 @@ namespace Tensorflow.Functions
var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}";
var backward_function_attr = new Dictionary<string, string>();
backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name;
gradients_wrt_outputs.append(backwards_graph.internal_captures());
gradients_wrt_outputs.append(backwards_graph.internal_captures);
backwards_graph.Inputs = gradients_wrt_outputs;
backwards_graph.Outputs = gradients_wrt_inputs;



+8 -3   src/TensorFlowNET.Core/Graphs/FuncGraph.cs

@@ -21,15 +21,20 @@ namespace Tensorflow.Graphs
public Tensors Outputs { get; set; } = new Tensors();
public Dictionary<string, string> Attrs { get; set; }

public Dictionary<long, (Tensor, Tensor)> _captures
Dictionary<long, (Tensor, Tensor)> _captures
= new Dictionary<long, (Tensor, Tensor)>();

public Tensor[] external_captures()
public Tensor[] external_captures
=> _captures.Select(x => x.Value.Item1).ToArray();
public (Tensor, Tensor)[] captures
=> _captures.Values.Select(x => x).ToArray();

public Tensor[] internal_captures()
public Tensor[] internal_captures
=> _captures.Select(x => x.Value.Item2).ToArray();

public Tensor[] captured_inputs
=> external_captures;

/// <summary>
/// Construct a new FuncGraph.
/// </summary>


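With external_captures and internal_captures now exposed as properties (and the backing _captures dictionary made private), call sites simply drop the parentheses. A hedged sketch of how the capture pairing reads after this change; the graph contents are assumed:

using System;
using Tensorflow;
using Tensorflow.Graphs;

var func_graph = new FuncGraph("demo_graph");
// ... assume ops built inside func_graph have captured tensors from an outer graph ...

// Each entry pairs the tensor in the outer graph with its placeholder inside func_graph.
foreach (var (external_t, internal_t) in func_graph.captures)
    Console.WriteLine($"{external_t.name} -> {internal_t.name}");

Tensor[] outer = func_graph.external_captures;   // Item1 of every pair, also exposed as captured_inputs
Tensor[] inner = func_graph.internal_captures;   // Item2 of every pair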
+175 -0   src/TensorFlowNET.Core/Graphs/SubGraphUtility.cs

@@ -0,0 +1,175 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Linq;
using static Tensorflow.Binding;

namespace Tensorflow.Graphs
{
public class SubGraphUtility
{
/// <summary>
/// Copies the tensors and all of their inputs recursively into the destination graph.
/// </summary>
/// <param name="init_tensors">Tensors to lift.</param>
/// <param name="graph">Destination FuncGraph.</param>
/// <param name="sources">Tensors treated as already lifted; discovered placeholders are appended.</param>
/// <param name="add_sources">Whether placeholders found while walking the subgraph become sources.</param>
/// <param name="handle_captures">Whether source tensors are re-captured into the destination graph.</param>
/// <param name="base_graph">Graph the tensors currently belong to; defaults to the graph of the first tensor.</param>
/// <param name="op_map">Existing op mapping to extend.</param>
/// <returns>Map from each copied op/tensor to its copy in the destination graph.</returns>
public static Dictionary<ITensorOrOperation, Operation> lift_to_graph(Tensors init_tensors,
FuncGraph graph,
List<Tensor> sources,
bool add_sources = false,
bool handle_captures = false,
Graph base_graph = null,
Dictionary<ITensorOrOperation, Operation> op_map = null)
{
base_graph = base_graph ?? init_tensors[0].graph;
op_map = op_map ?? new Dictionary<ITensorOrOperation, Operation>();
var visited_ops = sources.Select(x => x.op).ToList();
foreach (var init_tensor in init_tensors)
{
var src = map_subgraph(init_tensor, sources, visited_ops, add_sources);
sources.AddRange(src);
}

var ops_to_copy = new List<Operation>();
var marked_ops = new List<Operation>();
var ops_to_visit = new Stack<Operation>(init_tensors.Select(x => x.op));
var unvisited_ops = new List<Operation>(ops_to_visit.ToList());
while (unvisited_ops.Count > 0)
{
while(ops_to_visit.Count > 0)
{
var op = ops_to_visit.Pop();
if (marked_ops.Contains(op))
continue;
marked_ops.Add(op);
ops_to_copy.append(op);
foreach(var inp in op.inputs)
{

}
}
// difference_update
unvisited_ops.difference_update(marked_ops);
if (unvisited_ops.Count > 0)
ops_to_visit.Push(unvisited_ops.Last());
}

// When lifting from one FuncGraph to another, we will need to capture the
// relevant tensors as well.
var inverse_captures = new Dictionary<Tensor, Tensor>();
Tensor[] internal_captures = null;
if (base_graph is FuncGraph base_func_graph)
{
var captures = base_func_graph.captures;
foreach (var (external_capture, internal_capture) in captures)
inverse_captures[internal_capture] = external_capture;
internal_captures = base_func_graph.internal_captures;
}

graph.as_default();
var source_ops = new List<Operation>();
// Add the sources in the same order as the original graph.
foreach (var s in internal_captures)
{
if (sources.Contains(s))
{
sources.Remove(s);
source_ops.Add(s.op);
_copy_source(s: s,
graph: graph,
op_map: op_map,
handle_captures: handle_captures,
inverse_captures: inverse_captures,
base_graph: base_graph);
}
}

foreach(var op in reversed(ops_to_copy))
{
if (source_ops.Contains(op) || op_map.ContainsKey(op))
continue;
_copy_non_source(op, graph, op_map, base_graph);
}

return op_map;
}

static void _copy_source(Tensor s,
FuncGraph graph,
Dictionary<ITensorOrOperation, Operation> op_map,
bool handle_captures,
Dictionary<Tensor, Tensor> inverse_captures,
Graph base_graph)
{
Tensor copied_placeholder = null;
if (handle_captures && inverse_captures.ContainsKey(s))
copied_placeholder = graph.capture(inverse_captures[s], name: s.op.name);
else
throw new NotImplementedException("");
op_map[s] = copied_placeholder;
// Add an entry for the op of the source tensor so that if there are any nodes
// depending on that op via control dependencies it can work correctly.
op_map[s.op] = copied_placeholder.op;
}

static void _copy_non_source(Operation op, FuncGraph graph, Dictionary<ITensorOrOperation, Operation> op_map, Graph base_graph)
{
Operation copied_op = null;
var copied_inputs = new Tensors();
tf_with(ops.control_dependencies(new object[] { op }), delegate
{
// Create a new op in the destination graph if it doesn't exist before.
var attrs = new Dictionary<string, AttrValue>();
foreach (var attr_def in op.node_def.Attr)
attrs[attr_def.Key] = attr_def.Value;

copied_op = graph.create_op(op.type,
copied_inputs,
dtypes: op.outputs.Select(x => x.dtype).ToArray(),
attrs: attrs,
name: op.name);
});
op_map[op] = copied_op;
foreach (var (i, o) in enumerate(op.outputs))
op_map[o] = copied_op.outputs[i];
}

/// <summary>
/// Walks the graph backwards from init_tensor and captures the subgraph between it and sources.
/// </summary>
/// <param name="init_tensor">Tensor to start the backward walk from.</param>
/// <param name="sources">Tensors that bound the walk.</param>
/// <param name="visited_ops">Ops already visited; extended in place.</param>
/// <param name="add_sources">Whether placeholders encountered during the walk become extra sources.</param>
/// <returns>Extra source tensors discovered during the walk.</returns>
public static List<Tensor> map_subgraph(Tensor init_tensor,
List<Tensor> sources,
List<Operation> visited_ops,
bool add_sources)
{
var ops_to_visit = new Stack<Operation>();
ops_to_visit.Push(init_tensor.op);
var extra_sources = new List<Tensor>();
while (ops_to_visit.Count > 0)
{
var op = ops_to_visit.Pop();
if (visited_ops.Contains(op))
continue;
visited_ops.Add(op);
bool should_raise = false;
if (should_raise)
throw new RuntimeError($"Unable to lift tensor {init_tensor.name}.");
if(op.type == "Placeholder")
{
extra_sources.AddRange(op.outputs);
}
foreach(var inp in op.inputs)
{

}
}
return extra_sources;
}
}
}

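A hedged sketch of how lift_to_graph is driven, mirroring the call added to BackendImpl.eval_in_eager_or_function further down; `outputs` and `source_graph` are assumed to already exist in the caller:

using System.Collections.Generic;
using Tensorflow;
using Tensorflow.Graphs;
using static Tensorflow.Graphs.SubGraphUtility;

// Assumed setup: `outputs` (a Tensors) was built in `source_graph` and should be
// re-created inside a scratch FuncGraph.
var exec_graph = new FuncGraph("scratch_graph");
var op_map = lift_to_graph(outputs, exec_graph,
    new List<Tensor>(),           // sources: start empty; map_subgraph appends placeholders
    add_sources: true,            // discovered placeholders become sources
    handle_captures: true,        // sources are re-captured into exec_graph
    base_graph: source_graph);    // graph the tensors currently live in

// op_map maps every op/tensor copied out of source_graph to its clone in exec_graph.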
+0 -16   src/TensorFlowNET.Core/Operations/gen_math_ops.cs

@@ -873,22 +873,6 @@ namespace Tensorflow
return _op.output;
}

public static Tensor mul(Tensor x, Tensor y, string name = null)
{
if (tf.Context.executing_eagerly())
{
var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Mul", name,
null,
x, y);
return results[0];
}

var _op = tf.OpDefLib._apply_op_helper("Mul", name, args: new { x, y });

return _op.output;
}

public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null)
{
if (tf.Context.executing_eagerly())


+34 -0   src/TensorFlowNET.Core/Operations/math_ops.cs

@@ -44,6 +44,23 @@ namespace Tensorflow
public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.add(x, y, name);

public static Tensor add_v2(Tensor x, Tensor y, string name = null)
=> tf.Context.RunInAutoMode2(
() => tf.OpDefLib._apply_op_helper("AddV2", name, new { x, y }).output,
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"AddV2", name,
null,
x, y).FirstOrDefault(),
(op) =>
{
var attrs = new object[]
{
"T", op.get_attr<TF_DataType>("T")
};
tf.Runner.RecordGradient("AddV2", op.inputs, attrs, op.outputs);
},
new Tensors(x, y));

public static Tensor add_v2<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.add_v2(x, y, name);

@@ -251,6 +268,23 @@ namespace Tensorflow
public static Tensor sqrt(Tensor x, string name = null)
=> gen_math_ops.sqrt(x, name: name);

public static Tensor multiply(Tensor x, Tensor y, string name = null)
=> tf.Context.RunInAutoMode2(
() => tf.OpDefLib._apply_op_helper("Mul", name, new { x, y }).output,
() => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Mul", name,
null,
x, y).FirstOrDefault(),
(op) =>
{
var attrs = new object[]
{
"T", op.get_attr<TF_DataType>("T")
};
tf.Runner.RecordGradient("Mul", op.inputs, attrs, op.outputs);
},
new Tensors(x, y));

public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null)
=> gen_math_ops.mul(x, y, name: name);



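In effect, the new wrappers make the same call usable in both modes: with eager tensors they go through TFE_FastPathExecute and record a gradient, while in graph mode they fall back to OpDefLib. A short sketch with illustrative constants:

using Tensorflow;
using static Tensorflow.Binding;

var a = tf.constant(2.0f);
var b = tf.constant(3.0f);

// Eager tensors: fast-path execution plus RecordGradient, so GradientTape can
// differentiate through these ops; with graph tensors the same calls build ops.
var sum = math_ops.add_v2(a, b);        // 5.0
var prod = math_ops.multiply(a, b);     // 6.0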
+1 -10   src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs

@@ -309,25 +309,19 @@ namespace Tensorflow
private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y)
{
TF_DataType dtype = TF_DataType.DtInvalid;
bool switchToGraphModeTemp = !tf.executing_eagerly();

if (x is Tensor tl)
{
dtype = tl.dtype.as_base_dtype();
switchToGraphModeTemp = switchToGraphModeTemp || !tl.IsEagerTensor;
}

if (y is Tensor tr)
{
dtype = tr.dtype.as_base_dtype();
switchToGraphModeTemp = switchToGraphModeTemp || !tr.IsEagerTensor;
}

return tf_with(ops.name_scope(null, name, new { x, y }), scope =>
{
if (switchToGraphModeTemp)
tf.Context.graph_mode();

Tensor result;
var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x");
var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y");
@@ -347,7 +341,7 @@ namespace Tensorflow
result = math_ops.truediv(x1, y1, name: scope);
break;
case "mul":
result = gen_math_ops.mul(x1, y1, name: scope);
result = math_ops.multiply(x1, y1, name: scope);
break;
case "sub":
result = gen_math_ops.sub(x1, y1, name: scope);
@@ -359,9 +353,6 @@ namespace Tensorflow
throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}");
}

if (switchToGraphModeTemp)
tf.Context.restore_mode();

return result;
});
}


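The practical effect of the BinaryOpWrapper change: the `*` operator no longer switches eager tensors into a temporary graph mode, and the mul branch now records a gradient via math_ops.multiply. A short sketch with illustrative values:

using Tensorflow;
using static Tensorflow.Binding;

var x = tf.constant(4.0f);
var y = tf.constant(0.5f);

// Resolves to BinaryOpWrapper("mul", x, y) -> math_ops.multiply, which stays in
// eager mode and records the gradient instead of calling gen_math_ops.mul directly.
var z = x * y;                          // 2.0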
+11 -13   src/TensorFlowNET.Core/Tensors/tensor_util.cs

@@ -69,27 +69,25 @@ namespace Tensorflow
int num_elements = np.prod(shape);
var tensor_dtype = tensor.Dtype.as_numpy_dtype();

if (tensor.TensorContent.Length > 0)
if (shape.Length > 0 && tensor.TensorContent.Length > 0)
{
return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype).reshape(shape);
}
else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16)
#pragma warning disable CS0642 // Possible mistaken empty statement
;
#pragma warning restore CS0642 // Possible mistaken empty statement
{
return np.array(tensor.HalfVal).reshape(shape);
}
else if (tensor.Dtype == DataType.DtFloat)
#pragma warning disable CS0642 // Possible mistaken empty statement
;
#pragma warning restore CS0642 // Possible mistaken empty statement
{
return np.array(tensor.FloatVal).reshape(shape);
}
else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype))
{
if (tensor.IntVal.Count == 1)
return np.repeat(np.array(tensor.IntVal[0]), num_elements).reshape(shape);
return np.array(tensor.IntVal).reshape(shape);
}
else if (tensor.Dtype == DataType.DtBool)
{
if (tensor.BoolVal.Count == 1)
return np.repeat(np.array(tensor.BoolVal[0]), num_elements).reshape(shape);
return np.array(tensor.BoolVal).reshape(shape);
}

throw new NotImplementedException("MakeNdarray");
@@ -396,11 +394,11 @@ would not be rank 1.", tensor.op.get_attr("axis")));
tensor.op.graph is FuncGraph func_graph)
{
int i = 0;
foreach (Tensor capture in func_graph.internal_captures())
foreach (Tensor capture in func_graph.internal_captures)
{
if (capture.GetType() == typeof(Tensor))
{
var external_capture = func_graph.external_captures()[i];
var external_capture = func_graph.external_captures[i];
return constant_value_as_shape(external_capture);
}


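With the added guard, the TensorContent fast path is only taken when the proto has a non-empty shape; otherwise MakeNdarray falls through to the typed *Val branches. A hedged sketch of that typed-field path (it assumes make_tensor_proto stores a lone float in FloatVal rather than TensorContent):

using NumSharp;
using Tensorflow;

// Scalar proto: empty shape, value assumed to land in FloatVal.
var proto = tensor_util.make_tensor_proto(3.14f);
// Empty shape -> skips the TensorContent fast path and reads FloatVal instead.
NDArray nd = tensor_util.MakeNdarray(proto);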

+40 -2   src/TensorFlowNET.Keras/BackendImpl.cs

@@ -17,8 +17,10 @@
using NumSharp;
using System;
using System.Collections.Generic;
using Tensorflow.Functions;
using Tensorflow.Graphs;
using static Tensorflow.Binding;
using static Tensorflow.Graphs.SubGraphUtility;

namespace Tensorflow.Keras
{
@@ -33,6 +35,7 @@ namespace Tensorflow.Keras
public Session _SESSION => ops.get_default_session();

public Graph _GRAPH;
FuncGraph _CURRENT_SCRATCH_GRAPH;
public Dictionary<Graph, GraphLearningPhase> _GRAPH_LEARNING_PHASES;
//Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS;
public bool _MANUAL_VAR_INIT = false;
@@ -89,6 +92,14 @@ namespace Tensorflow.Keras
return ops.get_default_graph();
}

FuncGraph _scratch_graph()
{
if (_CURRENT_SCRATCH_GRAPH == null)
_CURRENT_SCRATCH_GRAPH = new FuncGraph("keras_scratch_graph");
return _CURRENT_SCRATCH_GRAPH;
}

public int get_uid(string prefix)
{
var graph = tf.get_default_graph();
@@ -168,9 +179,36 @@ namespace Tensorflow.Keras
/// </summary>
/// <param name="outputs"></param>
/// <returns></returns>
public NDArray eval_in_eager_or_function(Tensor outputs)
public NDArray eval_in_eager_or_function(Tensors outputs)
{
return outputs.eval();
if (outputs[0].op.type == "Const")
return tensor_util.constant_value(outputs);
var source_graph = outputs.graph;
using var exec_graph = _scratch_graph();
var global_graph = get_graph();
if (source_graph == global_graph && exec_graph != global_graph)
{
var lifted_map = lift_to_graph(outputs, exec_graph,
new List<Tensor>(),
add_sources: true,
handle_captures: true,
base_graph: source_graph);
}
if (outputs[0].op.type == "Placeholder"
|| outputs[0].op.type == "StridedSlice")
return exec_graph.external_captures[0].numpy();

// Consolidate updates
exec_graph.as_default();
exec_graph.Inputs = exec_graph.internal_captures;
exec_graph.Outputs = outputs;
var graph_fn = new ConcreteFunction(exec_graph);

_CURRENT_SCRATCH_GRAPH = null;
// return outputs.eval();
throw new NotImplementedException("");
}

public class _DummyEagerGraph

