| @@ -46,6 +46,15 @@ namespace Tensorflow | |||
| } | |||
| } | |||
/// <summary>
/// Removes, in place, the first occurrence in <paramref name="list"/> of each element
/// found in <paramref name="list2"/>. Elements of <paramref name="list2"/> that are not
/// present in <paramref name="list"/> are ignored.
/// </summary>
/// <param name="list">The list that elements are removed from.</param>
/// <param name="list2">The elements to remove.</param>
public static void difference_update<T>(this IList<T> list, IList<T> list2)
{
    // Subtracting a list from itself would otherwise remove items from the very
    // collection being enumerated, which throws for List<T>; the result is simply empty.
    if (ReferenceEquals(list, list2))
    {
        list.Clear();
        return;
    }

    // Remove is a no-op (returns false) when the element is absent, so a separate
    // Contains check would only scan the list a second time.
    foreach (var el in list2)
        list.Remove(el);
}
/// <summary>
/// Lower-case alias that appends <paramref name="element"/> to <paramref name="list"/>
/// by calling its <c>Add</c> method.
/// </summary>
public static void add<T>(this IList<T> list, T element)
{
    list.Add(element);
}
| @@ -158,6 +167,13 @@ namespace Tensorflow | |||
| return Enumerable.Range(start, end - start); | |||
| } | |||
/// <summary>
/// Lazily yields the items of <paramref name="values"/> from the last index down to index 0.
/// The element count is read once, when enumeration starts.
/// </summary>
public static IEnumerable<T> reversed<T>(IList<T> values)
{
    var index = values.Count;
    while (index > 0)
    {
        index--;
        yield return values[index];
    }
}
| public static T New<T>() where T : ITensorFlowObject, new() | |||
| { | |||
| var instance = new T(); | |||
| @@ -284,7 +300,7 @@ namespace Tensorflow | |||
| for (int i = 0; i < len; i++) | |||
| yield return (i, values[i]); | |||
| } | |||
| public static IEnumerable<(int, T)> enumerate<T>(IEnumerable<T> values, int start = 0, int step = 1) | |||
| { | |||
| int i = 0; | |||
| @@ -14,7 +14,7 @@ namespace Tensorflow.Functions | |||
| { | |||
| IntPtr _handle; | |||
| FuncGraph func_graph; | |||
| public Tensor[] CapturedInputs => func_graph.external_captures(); | |||
| public Tensor[] CapturedInputs => func_graph.external_captures; | |||
| public string Name | |||
| { | |||
| @@ -37,7 +37,7 @@ namespace Tensorflow.Functions | |||
| func_graph.as_default(); | |||
| } | |||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs) | |||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | |||
| { | |||
| func_graph = graph; | |||
| @@ -93,7 +93,7 @@ namespace Tensorflow.Functions | |||
| grad_ys: gradients_wrt_outputs.ToArray(), | |||
| src_graph: _func_graph); | |||
| var captures_from_forward = backwards_graph.external_captures() | |||
| var captures_from_forward = backwards_graph.external_captures | |||
| .Where(x => !x.IsEagerTensor && x.graph == _func_graph) | |||
| .ToArray(); | |||
| foreach(var capture in captures_from_forward) | |||
| @@ -105,7 +105,7 @@ namespace Tensorflow.Functions | |||
| var forward_function_name = $"{_FORWARD_PREFIX}_{ops.uid()}"; | |||
| var backward_function_attr = new Dictionary<string, string>(); | |||
| backward_function_attr[FORWARD_FUNCTION_ATTRIBUTE_NAME] = forward_function_name; | |||
| gradients_wrt_outputs.append(backwards_graph.internal_captures()); | |||
| gradients_wrt_outputs.append(backwards_graph.internal_captures); | |||
| backwards_graph.Inputs = gradients_wrt_outputs; | |||
| backwards_graph.Outputs = gradients_wrt_inputs; | |||
| @@ -21,15 +21,20 @@ namespace Tensorflow.Graphs | |||
| public Tensors Outputs { get; set; } = new Tensors(); | |||
| public Dictionary<string, string> Attrs { get; set; } | |||
| public Dictionary<long, (Tensor, Tensor)> _captures | |||
| Dictionary<long, (Tensor, Tensor)> _captures | |||
| = new Dictionary<long, (Tensor, Tensor)>(); | |||
| public Tensor[] external_captures() | |||
| public Tensor[] external_captures | |||
| => _captures.Select(x => x.Value.Item1).ToArray(); | |||
/// <summary>
/// A freshly allocated array holding every value tuple currently stored in <c>_captures</c>.
/// </summary>
public (Tensor, Tensor)[] captures
{
    get { return _captures.Values.ToArray(); }
}
| public Tensor[] internal_captures() | |||
| public Tensor[] internal_captures | |||
| => _captures.Select(x => x.Value.Item2).ToArray(); | |||
/// <summary>
/// Alias for <c>external_captures</c>.
/// </summary>
public Tensor[] captured_inputs
{
    get { return external_captures; }
}
| /// <summary> | |||
| /// Construct a new FuncGraph. | |||
| /// </summary> | |||
| @@ -0,0 +1,175 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using System.Linq; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Graphs | |||
| { | |||
/// <summary>
/// Helpers for copying ("lifting") operations from one graph into a <see cref="FuncGraph"/>.
/// </summary>
public class SubGraphUtility
{
    /// <summary>
    /// Copies the tensor and all its inputs recursively to the outer graph.
    /// </summary>
    /// <param name="init_tensors">Tensors whose ops seed the traversal.</param>
    /// <param name="graph">Graph that receives the copies; it is made the default graph before copying starts.</param>
    /// <param name="sources">
    /// Tensors treated as sources. This list is modified: tensors returned by
    /// <c>map_subgraph</c> are appended, and each source that gets copied is removed.
    /// A null value is treated as an empty list.
    /// </param>
    /// <param name="add_sources">Passed through to <c>map_subgraph</c>.</param>
    /// <param name="handle_captures">Passed through to the source-copy step, which only supports captured tensors when this is true.</param>
    /// <param name="base_graph">Graph being lifted from; defaults to the graph of the first tensor in <paramref name="init_tensors"/>.</param>
    /// <param name="op_map">Existing map to extend; a new dictionary is created when null.</param>
    /// <returns>The map from original tensors/operations to their copies in <paramref name="graph"/>.</returns>
    public static Dictionary<ITensorOrOperation, Operation> lift_to_graph(Tensors init_tensors,
        FuncGraph graph,
        List<Tensor> sources,
        bool add_sources = false,
        bool handle_captures = false,
        Graph base_graph = null,
        Dictionary<ITensorOrOperation, Operation> op_map = null)
    {
        base_graph = base_graph ?? init_tensors[0].graph;
        op_map = op_map ?? new Dictionary<ITensorOrOperation, Operation>();
        // A missing source list means "no known sources" rather than a crash.
        sources = sources ?? new List<Tensor>();

        var visited_ops = sources.Select(x => x.op).ToList();
        foreach (var init_tensor in init_tensors)
        {
            var src = map_subgraph(init_tensor, sources, visited_ops, add_sources);
            sources.AddRange(src);
        }

        var ops_to_copy = new List<Operation>();
        var marked_ops = new List<Operation>();
        var ops_to_visit = new Stack<Operation>(init_tensors.Select(x => x.op));
        var unvisited_ops = new List<Operation>(ops_to_visit);
        while (unvisited_ops.Count > 0)
        {
            while (ops_to_visit.Count > 0)
            {
                var op = ops_to_visit.Pop();
                if (marked_ops.Contains(op))
                    continue;
                marked_ops.Add(op);
                ops_to_copy.append(op);
                // TODO: the input traversal is not implemented yet. The loop body is
                // empty, so only the seed ops are ever marked and copied.
                foreach (var inp in op.inputs)
                {
                }
            }

            unvisited_ops.difference_update(marked_ops);
            if (unvisited_ops.Count > 0)
                ops_to_visit.Push(unvisited_ops.Last());
        }

        // When lifting from one FuncGraph to another, we will need to capture the
        // relevant tensors as well.
        var inverse_captures = new Dictionary<Tensor, Tensor>();
        // Stays empty when base_graph is not a FuncGraph, so the loop below is skipped
        // instead of dereferencing null.
        var internal_captures = new Tensor[0];
        if (base_graph is FuncGraph base_func_graph)
        {
            foreach (var (external_capture, internal_capture) in base_func_graph.captures)
                inverse_captures[internal_capture] = external_capture;
            internal_captures = base_func_graph.internal_captures;
        }

        graph.as_default();
        var source_ops = new List<Operation>();
        // Add the sources in the same order as the original graph.
        foreach (var s in internal_captures)
        {
            // Remove reports whether the tensor was a source, so no Contains pre-check is needed.
            if (!sources.Remove(s))
                continue;

            source_ops.Add(s.op);
            _copy_source(s: s,
                graph: graph,
                op_map: op_map,
                handle_captures: handle_captures,
                inverse_captures: inverse_captures,
                base_graph: base_graph);
        }

        foreach (var op in reversed(ops_to_copy))
        {
            if (source_ops.Contains(op) || op_map.ContainsKey(op))
                continue;
            _copy_non_source(op, graph, op_map, base_graph);
        }

        return op_map;
    }

    /// <summary>
    /// Captures the external counterpart of source tensor <paramref name="s"/> into
    /// <paramref name="graph"/> and records the result in <paramref name="op_map"/>
    /// for both the tensor and its op.
    /// </summary>
    /// <param name="s">Source tensor to copy.</param>
    /// <param name="graph">Graph that captures the external tensor.</param>
    /// <param name="op_map">Map updated with the entries for <paramref name="s"/> and <c>s.op</c>.</param>
    /// <param name="handle_captures">Must be true; any other case is not implemented.</param>
    /// <param name="inverse_captures">Map from internal capture to external capture.</param>
    /// <param name="base_graph">Not read by this method.</param>
    /// <exception cref="NotImplementedException">
    /// Thrown when <paramref name="handle_captures"/> is false or <paramref name="s"/>
    /// has no entry in <paramref name="inverse_captures"/>.
    /// </exception>
    static void _copy_source(Tensor s,
        FuncGraph graph,
        Dictionary<ITensorOrOperation, Operation> op_map,
        bool handle_captures,
        Dictionary<Tensor, Tensor> inverse_captures,
        Graph base_graph)
    {
        Tensor copied_placeholder;
        if (handle_captures && inverse_captures.TryGetValue(s, out var external_capture))
            copied_placeholder = graph.capture(external_capture, name: s.op.name);
        else
            throw new NotImplementedException(
                $"Copying source tensor {s.name} is only implemented for captured tensors with handle_captures enabled.");

        op_map[s] = copied_placeholder;
        // Add an entry for the op of the source tensor so that if there are any nodes
        // depending on that op via control dependencies it can work correctly.
        op_map[s.op] = copied_placeholder.op;
    }

    /// <summary>
    /// Creates a copy of <paramref name="op"/> in <paramref name="graph"/> with the same
    /// type, name, attributes and output dtypes, then records the op and each of its
    /// outputs in <paramref name="op_map"/>.
    /// </summary>
    /// <param name="op">Operation to copy.</param>
    /// <param name="graph">Graph in which the copy is created.</param>
    /// <param name="op_map">Map updated with the entries for <paramref name="op"/> and its outputs.</param>
    /// <param name="base_graph">Not read by this method.</param>
    static void _copy_non_source(Operation op, FuncGraph graph, Dictionary<ITensorOrOperation, Operation> op_map, Graph base_graph)
    {
        Operation copied_op = null;
        // NOTE(review): no inputs are mapped yet, so the copy is created with an empty input list.
        var copied_inputs = new Tensors();
        tf_with(ops.control_dependencies(new object[] { op }), delegate
        {
            // Create a new op in the destination graph if it doesn't exist before.
            var attrs = new Dictionary<string, AttrValue>();
            foreach (var attr_def in op.node_def.Attr)
                attrs[attr_def.Key] = attr_def.Value;

            copied_op = graph.create_op(op.type,
                copied_inputs,
                dtypes: op.outputs.Select(x => x.dtype).ToArray(),
                attrs: attrs,
                name: op.name);
        });

        op_map[op] = copied_op;
        foreach (var (i, o) in enumerate(op.outputs))
            op_map[o] = copied_op.outputs[i];
    }

    /// <summary>
    /// Walk a Graph and capture the subgraph between init_tensor and sources.
    /// </summary>
    /// <param name="init_tensor">Tensor whose op starts the walk.</param>
    /// <param name="sources">Not read by this method at present.</param>
    /// <param name="visited_ops">Ops already visited; every op reached here is appended to it.</param>
    /// <param name="add_sources">Not read by this method at present.</param>
    /// <returns>The outputs of every visited op whose type is "Placeholder".</returns>
    public static List<Tensor> map_subgraph(Tensor init_tensor,
        List<Tensor> sources,
        List<Operation> visited_ops,
        bool add_sources)
    {
        var ops_to_visit = new Stack<Operation>();
        ops_to_visit.Push(init_tensor.op);
        var extra_sources = new List<Tensor>();
        while (ops_to_visit.Count > 0)
        {
            var op = ops_to_visit.Pop();
            if (visited_ops.Contains(op))
                continue;
            visited_ops.Add(op);

            // Placeholder for a check that is not implemented: the flag is always
            // false, so the exception below is never thrown.
            bool should_raise = false;
            if (should_raise)
                throw new RuntimeError($"Unable to lift tensor {init_tensor.name}.");

            if (op.type == "Placeholder")
            {
                extra_sources.AddRange(op.outputs);
            }

            // TODO: the input traversal is not implemented yet. Nothing is pushed onto
            // ops_to_visit, so only init_tensor's own op is ever visited.
            foreach (var inp in op.inputs)
            {
            }
        }

        return extra_sources;
    }
}
| } | |||
| @@ -873,22 +873,6 @@ namespace Tensorflow | |||
| return _op.output; | |||
| } | |||
/// <summary>
/// Runs the "Mul" op on <paramref name="x"/> and <paramref name="y"/>.
/// </summary>
/// <param name="x">First operand.</param>
/// <param name="y">Second operand.</param>
/// <param name="name">Optional op name.</param>
/// <returns>
/// The op's output when building a graph, or the first result of the fast-path
/// execution when executing eagerly.
/// </returns>
public static Tensor mul(Tensor x, Tensor y, string name = null)
{
    // Not eager: add the op through the op-def library.
    if (!tf.Context.executing_eagerly())
        return tf.OpDefLib._apply_op_helper("Mul", name, args: new { x, y }).output;

    var eager_outputs = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
        "Mul", name,
        null,
        x, y);
    return eager_outputs[0];
}
| public static Tensor mul<Tx, Ty>(Tx x, Ty y, string name = null) | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| @@ -44,6 +44,23 @@ namespace Tensorflow | |||
/// <summary>
/// Generic overload that forwards its arguments to <c>gen_math_ops.add</c>.
/// </summary>
public static Tensor add<Tx, Ty>(Tx x, Ty y, string name = null)
{
    return gen_math_ops.add(x, y, name);
}
/// <summary>
/// Runs the "AddV2" op on <paramref name="x"/> and <paramref name="y"/> through
/// <c>tf.Context.RunInAutoMode2</c>, supplying a graph callback, an eager callback
/// and a gradient-recording callback.
/// </summary>
/// <param name="x">First operand.</param>
/// <param name="y">Second operand.</param>
/// <param name="name">Optional op name.</param>
public static Tensor add_v2(Tensor x, Tensor y, string name = null)
{
    return tf.Context.RunInAutoMode2(
        // Graph callback: add the op through the op-def library.
        () => tf.OpDefLib._apply_op_helper("AddV2", name, new { x, y }).output,
        // Eager callback: fast-path execution, taking the first result (or null if there is none).
        () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
            "AddV2", name,
            null,
            x, y).FirstOrDefault(),
        // Gradient callback: record the op together with its "T" attribute.
        (created_op) =>
        {
            var gradient_attrs = new object[]
            {
                "T", created_op.get_attr<TF_DataType>("T")
            };
            tf.Runner.RecordGradient("AddV2", created_op.inputs, gradient_attrs, created_op.outputs);
        },
        new Tensors(x, y));
}
/// <summary>
/// Generic overload that forwards its arguments to <c>gen_math_ops.add_v2</c>.
/// </summary>
public static Tensor add_v2<Tx, Ty>(Tx x, Ty y, string name = null)
{
    return gen_math_ops.add_v2(x, y, name);
}
| @@ -251,6 +268,23 @@ namespace Tensorflow | |||
/// <summary>
/// Forwards to <c>gen_math_ops.sqrt</c> for <paramref name="x"/>.
/// </summary>
public static Tensor sqrt(Tensor x, string name = null)
{
    return gen_math_ops.sqrt(x, name: name);
}
/// <summary>
/// Runs the "Mul" op on <paramref name="x"/> and <paramref name="y"/> through
/// <c>tf.Context.RunInAutoMode2</c>, supplying a graph callback, an eager callback
/// and a gradient-recording callback.
/// </summary>
/// <param name="x">First operand.</param>
/// <param name="y">Second operand.</param>
/// <param name="name">Optional op name.</param>
public static Tensor multiply(Tensor x, Tensor y, string name = null)
{
    return tf.Context.RunInAutoMode2(
        // Graph callback: add the op through the op-def library.
        () => tf.OpDefLib._apply_op_helper("Mul", name, new { x, y }).output,
        // Eager callback: fast-path execution, taking the first result (or null if there is none).
        () => tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
            "Mul", name,
            null,
            x, y).FirstOrDefault(),
        // Gradient callback: record the op together with its "T" attribute.
        (created_op) =>
        {
            var gradient_attrs = new object[]
            {
                "T", created_op.get_attr<TF_DataType>("T")
            };
            tf.Runner.RecordGradient("Mul", created_op.inputs, gradient_attrs, created_op.outputs);
        },
        new Tensors(x, y));
}
/// <summary>
/// Generic overload that forwards its arguments to <c>gen_math_ops.mul</c>.
/// </summary>
public static Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null)
{
    return gen_math_ops.mul(x, y, name: name);
}
| @@ -309,25 +309,19 @@ namespace Tensorflow | |||
| private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) | |||
| { | |||
| TF_DataType dtype = TF_DataType.DtInvalid; | |||
| bool switchToGraphModeTemp = !tf.executing_eagerly(); | |||
| if (x is Tensor tl) | |||
| { | |||
| dtype = tl.dtype.as_base_dtype(); | |||
| switchToGraphModeTemp = switchToGraphModeTemp || !tl.IsEagerTensor; | |||
| } | |||
| if (y is Tensor tr) | |||
| { | |||
| dtype = tr.dtype.as_base_dtype(); | |||
| switchToGraphModeTemp = switchToGraphModeTemp || !tr.IsEagerTensor; | |||
| } | |||
| return tf_with(ops.name_scope(null, name, new { x, y }), scope => | |||
| { | |||
| if (switchToGraphModeTemp) | |||
| tf.Context.graph_mode(); | |||
| Tensor result; | |||
| var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | |||
| var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); | |||
| @@ -347,7 +341,7 @@ namespace Tensorflow | |||
| result = math_ops.truediv(x1, y1, name: scope); | |||
| break; | |||
| case "mul": | |||
| result = gen_math_ops.mul(x1, y1, name: scope); | |||
| result = math_ops.multiply(x1, y1, name: scope); | |||
| break; | |||
| case "sub": | |||
| result = gen_math_ops.sub(x1, y1, name: scope); | |||
| @@ -359,9 +353,6 @@ namespace Tensorflow | |||
| throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}"); | |||
| } | |||
| if (switchToGraphModeTemp) | |||
| tf.Context.restore_mode(); | |||
| return result; | |||
| }); | |||
| } | |||
| @@ -69,27 +69,25 @@ namespace Tensorflow | |||
| int num_elements = np.prod(shape); | |||
| var tensor_dtype = tensor.Dtype.as_numpy_dtype(); | |||
| if (tensor.TensorContent.Length > 0) | |||
| if (shape.Length > 0 && tensor.TensorContent.Length > 0) | |||
| { | |||
| return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype).reshape(shape); | |||
| } | |||
| else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) | |||
| #pragma warning disable CS0642 // Possible mistaken empty statement | |||
| ; | |||
| #pragma warning restore CS0642 // Possible mistaken empty statement | |||
| { | |||
| return np.array(tensor.HalfVal).reshape(shape); | |||
| } | |||
| else if (tensor.Dtype == DataType.DtFloat) | |||
| #pragma warning disable CS0642 // Possible mistaken empty statement | |||
| ; | |||
| #pragma warning restore CS0642 // Possible mistaken empty statement | |||
| { | |||
| return np.array(tensor.FloatVal).reshape(shape); | |||
| } | |||
| else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype)) | |||
| { | |||
| if (tensor.IntVal.Count == 1) | |||
| return np.repeat(np.array(tensor.IntVal[0]), num_elements).reshape(shape); | |||
| return np.array(tensor.IntVal).reshape(shape); | |||
| } | |||
| else if (tensor.Dtype == DataType.DtBool) | |||
| { | |||
| if (tensor.BoolVal.Count == 1) | |||
| return np.repeat(np.array(tensor.BoolVal[0]), num_elements).reshape(shape); | |||
| return np.array(tensor.BoolVal).reshape(shape); | |||
| } | |||
| throw new NotImplementedException("MakeNdarray"); | |||
| @@ -396,11 +394,11 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||
| tensor.op.graph is FuncGraph func_graph) | |||
| { | |||
| int i = 0; | |||
| foreach (Tensor capture in func_graph.internal_captures()) | |||
| foreach (Tensor capture in func_graph.internal_captures) | |||
| { | |||
| if (capture.GetType() == typeof(Tensor)) | |||
| { | |||
| var external_capture = func_graph.external_captures()[i]; | |||
| var external_capture = func_graph.external_captures[i]; | |||
| return constant_value_as_shape(external_capture); | |||
| } | |||
| @@ -17,8 +17,10 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Graphs; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.Graphs.SubGraphUtility; | |||
| namespace Tensorflow.Keras | |||
| { | |||
| @@ -33,6 +35,7 @@ namespace Tensorflow.Keras | |||
| public Session _SESSION => ops.get_default_session(); | |||
| public Graph _GRAPH; | |||
| FuncGraph _CURRENT_SCRATCH_GRAPH; | |||
| public Dictionary<Graph, GraphLearningPhase> _GRAPH_LEARNING_PHASES; | |||
| //Dictionary<Graph, Dictionary<string, int>> PER_GRAPH_LAYER_NAME_UIDS; | |||
| public bool _MANUAL_VAR_INIT = false; | |||
| @@ -89,6 +92,14 @@ namespace Tensorflow.Keras | |||
| return ops.get_default_graph(); | |||
| } | |||
/// <summary>
/// Returns the cached scratch graph, first creating a <c>FuncGraph</c> named
/// "keras_scratch_graph" when none is cached.
/// </summary>
FuncGraph _scratch_graph()
{
    if (_CURRENT_SCRATCH_GRAPH != null)
        return _CURRENT_SCRATCH_GRAPH;

    _CURRENT_SCRATCH_GRAPH = new FuncGraph("keras_scratch_graph");
    return _CURRENT_SCRATCH_GRAPH;
}
| public int get_uid(string prefix) | |||
| { | |||
| var graph = tf.get_default_graph(); | |||
| @@ -168,9 +179,36 @@ namespace Tensorflow.Keras | |||
| /// </summary> | |||
| /// <param name="outputs"></param> | |||
| /// <returns></returns> | |||
| public NDArray eval_in_eager_or_function(Tensor outputs) | |||
| public NDArray eval_in_eager_or_function(Tensors outputs) | |||
| { | |||
| return outputs.eval(); | |||
| if (outputs[0].op.type == "Const") | |||
| return tensor_util.constant_value(outputs); | |||
| var source_graph = outputs.graph; | |||
| using var exec_graph = _scratch_graph(); | |||
| var global_graph = get_graph(); | |||
| if (source_graph == global_graph && exec_graph != global_graph) | |||
| { | |||
| var lifted_map = lift_to_graph(outputs, exec_graph, | |||
| new List<Tensor>(), | |||
| add_sources: true, | |||
| handle_captures: true, | |||
| base_graph: source_graph); | |||
| } | |||
| if (outputs[0].op.type == "Placeholder" | |||
| || outputs[0].op.type == "StridedSlice") | |||
| return exec_graph.external_captures[0].numpy(); | |||
| // Consolidate updates | |||
| exec_graph.as_default(); | |||
| exec_graph.Inputs = exec_graph.internal_captures; | |||
| exec_graph.Outputs = outputs; | |||
| var graph_fn = new ConcreteFunction(exec_graph); | |||
| _CURRENT_SCRATCH_GRAPH = null; | |||
| // return outputs.eval(); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| public class _DummyEagerGraph | |||