| @@ -22,6 +22,7 @@ using System.ComponentModel; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using NumSharp.Utilities; | |||
| using System.Runtime.CompilerServices; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -50,7 +51,7 @@ namespace Tensorflow | |||
| => list.Add(element); | |||
/// <summary>
/// Python-style <c>append</c>: adds <paramref name="element"/> at the end of <paramref name="list"/>.
/// </summary>
/// <typeparam name="T">Element type of the list.</typeparam>
/// <param name="list">The list that receives the element.</param>
/// <param name="element">The element to place after the current last item.</param>
public static void append<T>(this IList<T> list, T element)
    // Inserting at index Count places the element after the current last item.
    => list.Insert(list.Count, element);
| public static T[] concat<T>(this IList<T> list1, IList<T> list2) | |||
| { | |||
| @@ -407,5 +408,37 @@ namespace Tensorflow | |||
| return true; | |||
| return false; | |||
| } | |||
/// <summary>
/// Determines whether every element of <paramref name="subset"/> also occurs in <paramref name="src"/>.
/// </summary>
/// <typeparam name="T">Element type; compared with the default equality comparer.</typeparam>
/// <param name="subset">The candidate subset.</param>
/// <param name="src">The sequence that must contain all elements of <paramref name="subset"/>.</param>
/// <returns>
/// <c>true</c> when all elements of <paramref name="subset"/> are found in <paramref name="src"/>
/// (including when <paramref name="subset"/> is empty); otherwise <c>false</c>.
/// </returns>
public static bool issubset<T>(this IEnumerable<T> subset, IEnumerable<T> src)
{
    // Enumerate src once and answer membership queries from a hash set.
    var lookup = new HashSet<T>(src);
    foreach (var element in subset)
    {
        // The first missing element decides the answer.
        if (!lookup.Contains(element))
            return false;
    }
    return true;
}
/// <summary>
/// Returns the value stored under <paramref name="key"/>; when the key is absent,
/// stores <paramref name="value"/> under it first and returns that value.
/// </summary>
/// <param name="dic">The dictionary to read and, if needed, update.</param>
/// <param name="key">The key to look up.</param>
/// <param name="value">The value to insert when <paramref name="key"/> is missing.</param>
/// <returns>The existing value for <paramref name="key"/>, or <paramref name="value"/> after inserting it.</returns>
public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
{
    // Single lookup instead of ContainsKey followed by the indexer.
    if (dic.TryGetValue(key, out var existing))
        return existing;

    dic[key] = value;
    return value;
}
/// <summary>
/// Returns the value stored under <paramref name="key"/>, or <paramref name="value"/> when the key is absent.
/// The dictionary is never modified.
/// </summary>
/// <param name="dic">The dictionary to read.</param>
/// <param name="key">The key to look up.</param>
/// <param name="value">The fallback returned when <paramref name="key"/> is missing.</param>
/// <returns>The stored value, or the fallback.</returns>
public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value)
{
    // Single lookup instead of ContainsKey followed by the indexer.
    return dic.TryGetValue(key, out var found) ? found : value;
}
| } | |||
| } | |||
| @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||
| _channels_first = args.DataFormat == "channels_first"; | |||
| } | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| if (_channels_first) | |||
| { | |||
| @@ -4,6 +4,7 @@ using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -21,6 +22,11 @@ namespace Tensorflow.Keras.Engine | |||
| List<Layer> _input_layers; | |||
| List<KerasHistory> _input_coordinates; | |||
| List<KerasHistory> _output_coordinates; | |||
| public string[] NetworkNodes { get; set; } | |||
| public Dictionary<int, List<Node>> NodesByDepth { get; set; } | |||
| public List<Layer> Layers { get; set; } | |||
| Dictionary<int, int> tensor_usage_count; | |||
| public Dictionary<int, int> TensorUsageCount => tensor_usage_count; | |||
| public Functional(Tensors inputs, Tensors outputs) | |||
| : base(new ModelArgs | |||
| @@ -33,6 +39,7 @@ namespace Tensorflow.Keras.Engine | |||
| _output_layers = new List<Layer>(); | |||
| _input_coordinates = new List<KerasHistory>(); | |||
| _output_coordinates = new List<KerasHistory>(); | |||
| tensor_usage_count = new Dictionary<int, int>(); | |||
| _init_graph_network(inputs, outputs); | |||
| } | |||
| @@ -67,16 +74,253 @@ namespace Tensorflow.Keras.Engine | |||
| _input_layers.append(layer); | |||
| _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); | |||
| } | |||
| // Keep track of the network's nodes and layers. | |||
| var (nodes, nodes_by_depth, layers, _) = MapGraphNetwork(inputs, outputs); | |||
| NetworkNodes = nodes; | |||
| NodesByDepth = nodes_by_depth; | |||
| Layers = layers; | |||
| ComputeTensorUsageCount(); | |||
| } | |||
/// <summary>
/// Fills <c>tensor_usage_count</c> with the number of times each tensor is consumed,
/// keyed by the tensor's <c>GetHashCode()</c>. A tensor is counted once for every node
/// that takes it as a Keras input and once more if it is one of the network outputs.
/// </summary>
void ComputeTensorUsageCount()
{
    // Ids of tensors known to be computable so far, starting with the network inputs.
    var available_tensors = new HashSet<int>(inputs.Select(x => x.GetHashCode()));

    // Dictionary key order is unspecified, so sort explicitly (deepest first).
    // The first (deepest) level is skipped - presumably the input nodes; TODO confirm.
    var depth_keys = NodesByDepth.Keys.OrderByDescending(x => x).Skip(1).ToArray();

    foreach (var depth in depth_keys)
    {
        foreach (var node in NodesByDepth[depth])
        {
            // Only count nodes whose inputs can all be computed already.
            if (!node.KerasInputs.All(x => available_tensors.Contains(x.GetHashCode())))
                continue;

            foreach (var tensor in node.KerasInputs)
                Increment(tensor.GetHashCode());

            foreach (var output_tensor in node.Outputs)
                available_tensors.Add(output_tensor.GetHashCode());
        }
    }

    // Each network output counts as one additional use.
    foreach (var tensor in outputs)
        Increment(tensor.GetHashCode());

    // Adds one to the counter for tensor_id, starting from zero for unseen ids.
    void Increment(int tensor_id)
    {
        tensor_usage_count.TryGetValue(tensor_id, out var count);
        tensor_usage_count[tensor_id] = count + 1;
    }
}
/// <summary>
/// Gathers the nodes and layers that lead to <paramref name="outputs"/> and assigns
/// a depth to each of them (nodes without outbound nodes have depth 0).
/// </summary>
/// <param name="inputs">Network input tensors; inputs not connected to any output are registered at depth 0.</param>
/// <param name="outputs">Network output tensors from which the graph is traversed.</param>
/// <returns>
/// A tuple of: the node keys, the nodes grouped by depth, the layers ordered from the
/// greatest depth to depth 0, and the layers grouped by depth.
/// The two dictionaries are not ordered by key.
/// </returns>
(string[], Dictionary<int, List<Node>>, List<Layer>, Dictionary<int, List<Layer>>) MapGraphNetwork(Tensors inputs, Tensors outputs)
{
    var (nodes_in_decreasing_depth, layer_indices) = BuildMap(outputs);

    // Kept as a growable list: keys of disconnected input layers are appended below,
    // which is not possible on a fixed-size array.
    var network_nodes = nodes_in_decreasing_depth
        .Select(node => MakeNodeKey(node.Layer.Name, node.Layer.InboundNodes.IndexOf(node)))
        .ToList();

    var nodes_depths = new Dictionary<Node, int>();
    var layers_depths = new Dictionary<Layer, int>();

    // In-place List<T>.Reverse: the loop below walks the nodes in reverse order.
    nodes_in_decreasing_depth.Reverse();
    foreach (var node in nodes_in_decreasing_depth)
    {
        // If the depth is not set, the node has no outbound nodes (depth 0).
        int depth = nodes_depths.SetDefault(node, 0);

        // Update the depth of the corresponding layer.
        int previous_depth = layers_depths.Get(node.Layer, 0);

        // If we've seen this layer before at a higher depth,
        // we should use that depth instead of the node depth.
        // This is necessary for shared layers that have inputs at different
        // depth levels in the graph.
        depth = Math.Max(depth, previous_depth);
        layers_depths[node.Layer] = depth;
        nodes_depths[node] = depth;

        // Update the depth of inbound nodes.
        // The "depth" of a node is the max of the depths
        // of all nodes it is connected to + 1.
        foreach (var node_dep in node.ParentNodes)
        {
            previous_depth = nodes_depths.Get(node_dep, 0);
            nodes_depths[node_dep] = Math.Max(depth + 1, previous_depth);
        }
    }

    // Handle inputs that are not connected to outputs.
    // We do not error out here because the inputs may be used to compute losses
    // and metrics.
    foreach (var input_t in inputs)
    {
        var (input_layer, _, _) = input_t.KerasHistory;
        if (!layers_depths.ContainsKey(input_layer))
        {
            layers_depths[input_layer] = 0;
            layer_indices[input_layer] = -1;
            nodes_depths[input_layer.InboundNodes[0]] = 0;
            network_nodes.Add(MakeNodeKey(input_layer.Name, 0));
        }
    }

    // Build a dict {depth: list of nodes with this depth}.
    var nodes_by_depth = new Dictionary<int, List<Node>>();
    foreach (var node in nodes_depths)
    {
        if (!nodes_by_depth.TryGetValue(node.Value, out var nodes_at_depth))
        {
            nodes_at_depth = new List<Node>();
            nodes_by_depth[node.Value] = nodes_at_depth;
        }
        nodes_at_depth.Add(node.Key);
    }

    // Build a dict {depth: list of layers with this depth}.
    var layers_by_depth = new Dictionary<int, List<Layer>>();
    foreach (var layer in layers_depths)
    {
        if (!layers_by_depth.TryGetValue(layer.Value, out var layers_at_depth))
        {
            layers_at_depth = new List<Layer>();
            layers_by_depth[layer.Value] = layers_at_depth;
        }
        layers_at_depth.Add(layer.Key);
    }

    // Get sorted list of layer depths. Dictionary key order is unspecified,
    // so the keys are sorted explicitly, greatest depth first.
    var depth_keys = layers_by_depth.Keys.OrderByDescending(d => d).ToArray();

    // Set self.layers ordered by depth.
    var layers = new List<Layer>();
    foreach (var depth in depth_keys)
    {
        var layers_for_depth = layers_by_depth[depth];
        // Network.layers needs to have a deterministic order:
        // here we order them by traversal order.
        layers_for_depth.Reverse();
        layers.AddRange(layers_for_depth);
    }

    return (network_nodes.ToArray(), nodes_by_depth, layers, layers_by_depth);
}
| Tensors run_internal_graph(Tensors inputs, Tensor state = null, bool is_training = false) | |||
/// <summary>
/// Builds the key that identifies one inbound node of a layer: the layer name,
/// the marker <c>_ib-</c>, then the node's index.
/// </summary>
/// <param name="layer_name">Name of the layer owning the node.</param>
/// <param name="node_index">Index of the node used in the key.</param>
/// <returns>The composed key, e.g. <c>dense_ib-0</c>.</returns>
string MakeNodeKey(string layer_name, int node_index)
{
    return string.Format("{0}_ib-{1}", layer_name, node_index);
}
/// <summary>
/// Walks the graph backwards from every output tensor and collects the nodes it reaches.
/// A node is added to the result only after all nodes feeding it, so the returned list
/// runs from inputs towards outputs.
/// </summary>
/// <param name="outputs">Output tensors that seed the traversal.</param>
/// <returns>
/// The collected nodes, and a map from each visited layer to the order in which it was first reached.
/// </returns>
(List<Node>, Dictionary<Layer, int>) BuildMap(Tensors outputs)
{
    // Bookkeeping shared by every recursive BuildMapHelper call.
    var done = new List<Node>();
    var visiting = new List<Node>();
    var ordered_nodes = new List<Node>();
    var traversal_index = new Dictionary<Layer, int>();

    foreach (var output_tensor in outputs)
    {
        BuildMapHelper(output_tensor, done, visiting, ordered_nodes, traversal_index);
    }

    return (ordered_nodes, traversal_index);
}
/// <summary>
/// Recursive step of <c>BuildMap</c>: visits the node that produced <paramref name="tensor"/>,
/// then every node feeding it, and records the node once all of its inputs have been handled.
/// </summary>
/// <param name="tensor">Tensor whose producing node is visited.</param>
/// <param name="finished_nodes">Nodes that have been fully processed.</param>
/// <param name="nodes_in_progress">Nodes currently on the recursion path; used to detect cycles.</param>
/// <param name="nodes_in_decreasing_depth">Result list; each node is appended after the nodes it depends on.</param>
/// <param name="layer_indices">Maps each layer to the order in which it was first reached.</param>
void BuildMapHelper(Tensor tensor,
    List<Node> finished_nodes,
    List<Node> nodes_in_progress,
    List<Node> nodes_in_decreasing_depth,
    Dictionary<Layer, int> layer_indices)
{
    var (layer, node_index, _) = tensor.KerasHistory;
    var current = layer.InboundNodes[node_index];

    // A node shared by several subgraphs is only processed once.
    if (finished_nodes.Contains(current))
    {
        return;
    }

    // Meeting a node that is still being processed means the graph loops back on itself.
    if (nodes_in_progress.Contains(current))
    {
        throw new ValueError($"The tensor {tensor.name} at layer {layer.Name} is part of a cycle.");
    }

    // Remember when each layer was first reached, for later layer sorting.
    if (!layer_indices.ContainsKey(layer))
    {
        var next_index = layer_indices.Count;
        layer_indices[layer] = next_index;
    }

    // Recurse into every tensor feeding this node before recording the node itself.
    nodes_in_progress.Add(current);
    foreach (var upstream in current.KerasInputs)
    {
        BuildMapHelper(upstream, finished_nodes, nodes_in_progress, nodes_in_decreasing_depth, layer_indices);
    }

    finished_nodes.Add(current);
    nodes_in_progress.Remove(current);
    nodes_in_decreasing_depth.Add(current);
}
/// <summary>
/// Runs the functional model by delegating to <c>run_internal_graph</c>.
/// </summary>
/// <param name="inputs">Input tensors of the model.</param>
/// <param name="state">Not used by this override.</param>
/// <param name="is_training">Forwarded as the training flag.</param>
/// <returns>Whatever <c>run_internal_graph</c> returns.</returns>
protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false)
    => run_internal_graph(inputs, training: is_training);
/// <summary>
/// Executes the layers of the network, deepest level first, keeping intermediate results
/// in a dictionary keyed by tensor id.
/// NOTE: collecting the output tensors is not implemented yet, so this method currently
/// always ends by throwing <see cref="NotImplementedException"/>.
/// </summary>
/// <param name="inputs">Tensors fed into the network inputs.</param>
/// <param name="training">Forwarded to each layer as <c>is_training</c>.</param>
/// <param name="mask">Optional masks, one per input, attached to the inputs as <c>KerasMask</c>.</param>
/// <returns>Nothing yet; see the note above.</returns>
/// <exception cref="NotImplementedException">Always thrown once all nodes have been executed.</exception>
Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null)
{
    if (mask != null)
    {
        // Attach the caller-supplied masks to their inputs.
        // NOTE(review): assumes mask holds at least one entry per input - confirm.
        Tensor[] masks = mask.ToArray();
        foreach (var (i, input_t) in enumerate(inputs))
            input_t.KerasMask = masks[i];
    }

    // tensor id -> one copy of the tensor per expected consumer.
    var tensor_dict = new Dictionary<int, Tensor[]>();
    foreach (var (x, y) in zip(this.inputs, inputs))
    {
        var y1 = conform_to_reference_input(y, x);
        var x_id = x.GetHashCode();
        tensor_dict[x_id] = Enumerable.Repeat(y1, tensor_usage_count[x_id]).ToArray();
    }

    // Dictionary key order is unspecified, so sort explicitly (deepest first).
    var depth_keys = NodesByDepth.Keys.OrderByDescending(d => d).ToArray();

    foreach (var depth in depth_keys)
    {
        foreach (var node in NodesByDepth[depth])
        {
            // Input tensors already exist.
            if (node.IsInput)
                continue;

            // NOTE(review): only the first flat input id is read and cleared here;
            // nodes with several inputs look unsupported - confirm.
            var layer_inputs = new Tensors(tensor_dict[node.FlatInputIds[0]]);
            tensor_dict[node.FlatInputIds[0]] = new Tensor[0];

            var layer_outputs = node.Layer.Apply(layer_inputs, is_training: training);

            // Update tensor_dict.
            foreach (var (x_id, y) in zip(node.FlatOutputIds, layer_outputs))
                tensor_dict[x_id] = Enumerable.Repeat(y, tensor_usage_count[x_id]).ToArray();
        }
    }

    // TODO: gather the network output tensors from tensor_dict and return them.
    throw new NotImplementedException("");
}
/// <summary>
/// Currently a pass-through: returns <paramref name="tensor"/> unchanged.
/// </summary>
/// <param name="tensor">The tensor to return.</param>
/// <param name="ref_input">Reference input; not used by the current implementation.</param>
/// <returns><paramref name="tensor"/>, unmodified.</returns>
Tensor conform_to_reference_input(Tensor tensor, Tensor ref_input)
    => tensor;
| } | |||
| } | |||
| @@ -9,10 +9,10 @@ namespace Tensorflow.Keras.Engine | |||
| /// </summary> | |||
| public class KerasHistory | |||
| { | |||
| public Layer layer; | |||
| Layer layer; | |||
| int node_index; | |||
| int tensor_index; | |||
| public Tensor tensor; | |||
| Tensor tensor; | |||
| public KerasHistory(Layer layer, int node_index, int tensor_index, Tensor tensor) | |||
| { | |||
| @@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine | |||
| if (!built) | |||
| MaybeBuild(inputs); | |||
| outputs = call_fn(inputs, state: state, is_training: is_training); | |||
| outputs = CallFn(inputs, state: state, is_training: is_training); | |||
| outputs = _set_connectivity_metadata_(inputs, outputs); | |||
| _handle_activity_regularization(inputs, outputs); | |||
| @@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Engine | |||
| if (!dynamic) | |||
| throw new NotImplementedException(""); | |||
| outputs = call_fn(inputs); | |||
| outputs = CallFn(inputs); | |||
| outputs = _set_connectivity_metadata_(inputs, outputs); | |||
| _handle_activity_regularization(inputs, outputs); | |||
| @@ -162,7 +162,7 @@ namespace Tensorflow.Keras.Engine | |||
| /// <param name="state"></param> | |||
| /// <param name="is_training"></param> | |||
| /// <returns></returns> | |||
| protected virtual Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| protected virtual Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| @@ -39,20 +39,42 @@ namespace Tensorflow.Keras.Engine | |||
| public Tensors Outputs => args.Outputs; | |||
| public TensorShape[] input_shapes; | |||
| public TensorShape[] output_shapes; | |||
| List<Tensor> kerasInputs = new List<Tensor>(); | |||
| public List<Tensor> KerasInputs = new List<Tensor>(); | |||
| public Layer Layer { get; set; } | |||
| public bool IsInput => args.InputTensors == null; | |||
| public int[] FlatInputIds { get; set; } | |||
| public int[] FlatOutputIds { get; set; } | |||
/// <summary>
/// The nodes that produced this node's Keras input tensors, in the order of <c>KerasInputs</c>.
/// Inputs without a Keras history, or whose history has no layer, contribute nothing.
/// A fresh array is built on every access.
/// </summary>
public Node[] ParentNodes
{
    get
    {
        var node_deps = new List<Node>();
        foreach (var kt in KerasInputs)
        {
            // A tensor without Keras history has no producing node to report.
            if (kt.KerasHistory == null)
                continue;

            var (layer, node_index, _) = kt.KerasHistory;
            if (layer != null)
                node_deps.Add(layer.InboundNodes[node_index]);
        }
        return node_deps.ToArray();
    }
}
| public Node(Layer layer, NodeArgs args) | |||
| { | |||
| this.args = args; | |||
| this.Layer = layer; | |||
| if (args.InputTensors != null) | |||
| kerasInputs.AddRange(args.InputTensors); | |||
| KerasInputs.AddRange(args.InputTensors); | |||
| // Wire up Node to Layers. | |||
| layer.InboundNodes.Add(this); | |||
| foreach (var kt in kerasInputs) | |||
| foreach (var kt in KerasInputs) | |||
| { | |||
| var inbound_layer = kt.KerasHistory.layer; | |||
| if (kt.KerasHistory == null) | |||
| continue; | |||
| var (inbound_layer, _, _) = kt.KerasHistory; | |||
| if (inbound_layer != null) | |||
| inbound_layer.OutboundNodes.Add(this); | |||
| } | |||
| @@ -61,6 +83,10 @@ namespace Tensorflow.Keras.Engine | |||
| var node_index = layer.InboundNodes.Count - 1; | |||
| foreach (var (i, tensor) in enumerate(Outputs)) | |||
| tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor); | |||
| // Cached for performance. | |||
| FlatInputIds = KerasInputs.Select(x => x.GetHashCode()).ToArray(); | |||
| FlatOutputIds = Outputs.Select(x => x.GetHashCode()).ToArray(); | |||
| } | |||
| } | |||
| } | |||
| @@ -23,9 +23,9 @@ namespace Tensorflow.Keras.Engine | |||
| built = true; | |||
| } | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| return base.call_fn(inputs, state, is_training); | |||
| return base.CallFn(inputs, state, is_training); | |||
| } | |||
| } | |||
| } | |||
| @@ -119,7 +119,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| Tensor outputs = null; | |||
| @@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false) | |||
| { | |||
| var outputs = _convolution_op.Apply(inputs, kernel); | |||
| if (use_bias) | |||
| @@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false) | |||
| { | |||
| Tensor outputs = null; | |||
| var rank = inputs.rank; | |||
| @@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.args = args; | |||
| } | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| var output = tf_utils.smart_cond(is_training, | |||
| () => tf.nn.dropout(inputs, | |||
| @@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers | |||
| built = true; | |||
| } | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| var dtype = inputs.dtype; | |||
| if (dtype != tf.int32 && dtype != tf.int64) | |||
| @@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers | |||
| .ToArray(); | |||
| } | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| return base.call_fn(inputs, state: state, is_training: is_training); | |||
| return base.CallFn(inputs, state: state, is_training: is_training); | |||
| } | |||
| } | |||
| } | |||
| @@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers | |||
| input_spec = new InputSpec(ndim: 4); | |||
| } | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| int[] pool_shape; | |||
| int[] strides; | |||
| @@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.args = args; | |||
| } | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| scale = math_ops.cast(args.Scale, args.DType); | |||
| offset = math_ops.cast(args.Offset, args.DType); | |||
| @@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers | |||
| this.input_spec = new InputSpec(ndim: 4); | |||
| } | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| return tf.keras.backend.spatial_2d_padding(inputs, | |||
| padding: padding, | |||
| @@ -74,7 +74,7 @@ namespace Tensorflow | |||
| /// <param name="training"></param> | |||
| /// <param name="state"></param> | |||
| /// <returns></returns> | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| var one = constant_op.constant(1, dtype: dtypes.int32); | |||
| // Parameters of gates are concatenated into one multiply for efficiency. | |||
| @@ -67,7 +67,7 @@ namespace Tensorflow | |||
| built = true; | |||
| } | |||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||
| { | |||
| // Most basic RNN: output = new_state = act(W * input + U * state + B). | |||
| var concat = array_ops.concat(new Tensor[] { inputs, state }, 1); | |||
| @@ -145,6 +145,7 @@ namespace Tensorflow | |||
| /// Keras History: (Layer, (node_index, tensor_index)) | |||
| /// </summary> | |||
| public KerasHistory KerasHistory { get; set; } | |||
| public Tensor KerasMask { get; set; } | |||
| /// <summary> | |||
| /// Updates the shape of this tensor. | |||