| @@ -22,6 +22,7 @@ using System.ComponentModel; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Linq; | using System.Linq; | ||||
| using NumSharp.Utilities; | using NumSharp.Utilities; | ||||
| using System.Runtime.CompilerServices; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -50,7 +51,7 @@ namespace Tensorflow | |||||
| => list.Add(element); | => list.Add(element); | ||||
| public static void append<T>(this IList<T> list, T element) | public static void append<T>(this IList<T> list, T element) | ||||
| => list.Add(element); | |||||
| => list.Insert(list.Count, element); | |||||
| public static T[] concat<T>(this IList<T> list1, IList<T> list2) | public static T[] concat<T>(this IList<T> list1, IList<T> list2) | ||||
| { | { | ||||
| @@ -407,5 +408,37 @@ namespace Tensorflow | |||||
| return true; | return true; | ||||
| return false; | return false; | ||||
| } | } | ||||
| public static bool issubset<T>(this IEnumerable<T> subset, IEnumerable<T> src) | |||||
| { | |||||
| bool issubset = true; | |||||
| foreach (var element in subset) | |||||
| { | |||||
| if (!src.Contains(element)) | |||||
| { | |||||
| issubset = false; | |||||
| continue; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| public static TValue SetDefault<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value) | |||||
| { | |||||
| if (dic.ContainsKey(key)) | |||||
| return dic[key]; | |||||
| dic[key] = value; | |||||
| return value; | |||||
| } | |||||
| public static TValue Get<TKey, TValue>(this Dictionary<TKey, TValue> dic, TKey key, TValue value) | |||||
| { | |||||
| if (dic.ContainsKey(key)) | |||||
| return dic[key]; | |||||
| return value; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -21,7 +21,7 @@ namespace Tensorflow.Keras.Engine | |||||
| _channels_first = args.DataFormat == "channels_first"; | _channels_first = args.DataFormat == "channels_first"; | ||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| if (_channels_first) | if (_channels_first) | ||||
| { | { | ||||
| @@ -4,6 +4,7 @@ using System.Linq; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| @@ -21,6 +22,11 @@ namespace Tensorflow.Keras.Engine | |||||
| List<Layer> _input_layers; | List<Layer> _input_layers; | ||||
| List<KerasHistory> _input_coordinates; | List<KerasHistory> _input_coordinates; | ||||
| List<KerasHistory> _output_coordinates; | List<KerasHistory> _output_coordinates; | ||||
| public string[] NetworkNodes { get; set; } | |||||
| public Dictionary<int, List<Node>> NodesByDepth { get; set; } | |||||
| public List<Layer> Layers { get; set; } | |||||
| Dictionary<int, int> tensor_usage_count; | |||||
| public Dictionary<int, int> TensorUsageCount => tensor_usage_count; | |||||
| public Functional(Tensors inputs, Tensors outputs) | public Functional(Tensors inputs, Tensors outputs) | ||||
| : base(new ModelArgs | : base(new ModelArgs | ||||
| @@ -33,6 +39,7 @@ namespace Tensorflow.Keras.Engine | |||||
| _output_layers = new List<Layer>(); | _output_layers = new List<Layer>(); | ||||
| _input_coordinates = new List<KerasHistory>(); | _input_coordinates = new List<KerasHistory>(); | ||||
| _output_coordinates = new List<KerasHistory>(); | _output_coordinates = new List<KerasHistory>(); | ||||
| tensor_usage_count = new Dictionary<int, int>(); | |||||
| _init_graph_network(inputs, outputs); | _init_graph_network(inputs, outputs); | ||||
| } | } | ||||
| @@ -67,16 +74,253 @@ namespace Tensorflow.Keras.Engine | |||||
| _input_layers.append(layer); | _input_layers.append(layer); | ||||
| _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); | _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x)); | ||||
| } | } | ||||
| // Keep track of the network's nodes and layers. | |||||
| var (nodes, nodes_by_depth, layers, _) = MapGraphNetwork(inputs, outputs); | |||||
| NetworkNodes = nodes; | |||||
| NodesByDepth = nodes_by_depth; | |||||
| Layers = layers; | |||||
| ComputeTensorUsageCount(); | |||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| void ComputeTensorUsageCount() | |||||
| { | { | ||||
| return run_internal_graph(inputs, state, is_training); | |||||
| var available_tensors = inputs.Select(x => x.GetHashCode()).ToList(); | |||||
| var depth_keys = NodesByDepth.Keys.Reverse().Skip(1).ToArray(); | |||||
| foreach(var depth in depth_keys) | |||||
| { | |||||
| foreach(var node in NodesByDepth[depth]) | |||||
| { | |||||
| var input_tensors = node.KerasInputs.Select(x => x.GetHashCode()).ToArray(); | |||||
| if (input_tensors.issubset(available_tensors)) | |||||
| { | |||||
| foreach (var tensor in node.KerasInputs) | |||||
| { | |||||
| if (!tensor_usage_count.ContainsKey(tensor.GetHashCode())) | |||||
| tensor_usage_count[tensor.GetHashCode()] = 0; | |||||
| tensor_usage_count[tensor.GetHashCode()] += 1; | |||||
| } | |||||
| foreach (var output_tensor in node.Outputs) | |||||
| available_tensors.Add(output_tensor.GetHashCode()); | |||||
| } | |||||
| } | |||||
| } | |||||
| foreach (var tensor in outputs) | |||||
| { | |||||
| if (!tensor_usage_count.ContainsKey(tensor.GetHashCode())) | |||||
| tensor_usage_count[tensor.GetHashCode()] = 0; | |||||
| tensor_usage_count[tensor.GetHashCode()] += 1; | |||||
| } | |||||
| } | |||||
| /// <summary> | |||||
| /// Validates a network's topology and gather its layers and nodes. | |||||
| /// </summary> | |||||
| /// <param name="inputs"></param> | |||||
| /// <param name="outputs"></param> | |||||
| (string[], Dictionary<int, List<Node>>, List<Layer>, Dictionary<int, List<Layer>>) MapGraphNetwork(Tensors inputs, Tensors outputs) | |||||
| { | |||||
| var (nodes_in_decreasing_depth, layer_indices) = BuildMap(outputs); | |||||
| var network_nodes = nodes_in_decreasing_depth | |||||
| .Select(node => MakeNodeKey(node.Layer.Name, node.Layer.InboundNodes.IndexOf(node))) | |||||
| .ToArray(); | |||||
| var nodes_depths = new Dictionary<Node, int>(); | |||||
| var layers_depths = new Dictionary<Layer, int>(); | |||||
| nodes_in_decreasing_depth.Reverse(); | |||||
| foreach (var node in nodes_in_decreasing_depth) | |||||
| { | |||||
| // If the depth is not set, the node has no outbound nodes (depth 0). | |||||
| int depth = nodes_depths.SetDefault(node, 0); | |||||
| // Update the depth of the corresponding layer | |||||
| int previous_depth = layers_depths.Get(node.Layer, 0); | |||||
| // If we've seen this layer before at a higher depth, | |||||
| // we should use that depth instead of the node depth. | |||||
| // This is necessary for shared layers that have inputs at different | |||||
| // depth levels in the graph. | |||||
| depth = Math.Max(depth, previous_depth); | |||||
| layers_depths[node.Layer] = depth; | |||||
| nodes_depths[node] = depth; | |||||
| // Update the depth of inbound nodes. | |||||
| // The "depth" of a node is the max of the depths | |||||
| // of all nodes it is connected to + 1. | |||||
| foreach(var node_dep in node.ParentNodes) | |||||
| { | |||||
| previous_depth = nodes_depths.Get(node_dep, 0); | |||||
| nodes_depths[node_dep] = Math.Max(depth + 1, previous_depth); | |||||
| } | |||||
| } | |||||
| // Handle inputs that are not connected to outputs. | |||||
| // We do not error out here because the inputs may be used to compute losses | |||||
| // and metrics. | |||||
| foreach(var input_t in inputs) | |||||
| { | |||||
| var (input_layer, _, _) = input_t.KerasHistory; | |||||
| if (!layers_depths.ContainsKey(input_layer)) | |||||
| { | |||||
| layers_depths[input_layer] = 0; | |||||
| layer_indices[input_layer] = -1; | |||||
| nodes_depths[input_layer.InboundNodes[0]] = 0; | |||||
| network_nodes.add(MakeNodeKey(input_layer.Name, 0)); | |||||
| } | |||||
| } | |||||
| // Build a dict {depth: list of nodes with this depth} | |||||
| var nodes_by_depth = new Dictionary<int, List<Node>>(); | |||||
| foreach (var node in nodes_depths) | |||||
| { | |||||
| if (!nodes_by_depth.ContainsKey(node.Value)) | |||||
| nodes_by_depth[node.Value] = new List<Node>(); | |||||
| nodes_by_depth[node.Value].append(node.Key); | |||||
| } | |||||
| var layers_by_depth = new Dictionary<int, List<Layer>>(); | |||||
| foreach (var layer in layers_depths) | |||||
| { | |||||
| if (!layers_by_depth.ContainsKey(layer.Value)) | |||||
| layers_by_depth[layer.Value] = new List<Layer>(); | |||||
| layers_by_depth[layer.Value].append(layer.Key); | |||||
| } | |||||
| // Get sorted list of layer depths. | |||||
| var depth_keys = layers_by_depth.Keys.Reverse(); | |||||
| // Set self.layers ordered by depth. | |||||
| var layers = new List<Layer>(); | |||||
| foreach(var depth in depth_keys) | |||||
| { | |||||
| var layers_for_depth = layers_by_depth[depth]; | |||||
| // Network.layers needs to have a deterministic order: | |||||
| // here we order them by traversal order. | |||||
| layers_for_depth.Reverse(); | |||||
| layers.AddRange(layers_for_depth); | |||||
| } | |||||
| // Get sorted list of node depths. | |||||
| depth_keys = nodes_by_depth.Keys.Reverse(); | |||||
| return (network_nodes, nodes_by_depth, layers, layers_by_depth); | |||||
| } | } | ||||
| Tensors run_internal_graph(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| string MakeNodeKey(string layer_name, int node_index) | |||||
| => $"{layer_name}_ib-{node_index}"; | |||||
| /// <summary> | |||||
| /// This method topologically sorts nodes in order from inputs to outputs. | |||||
| /// </summary> | |||||
| /// <param name="outputs"></param> | |||||
| (List<Node>, Dictionary<Layer, int>) BuildMap(Tensors outputs) | |||||
| { | { | ||||
| var finished_nodes = new List<Node>(); | |||||
| var nodes_in_progress = new List<Node>(); | |||||
| var nodes_in_decreasing_depth = new List<Node>(); | |||||
| var layer_indices = new Dictionary<Layer, int>(); | |||||
| foreach (var output in outputs) | |||||
| BuildMapHelper(output, | |||||
| finished_nodes, | |||||
| nodes_in_progress, | |||||
| nodes_in_decreasing_depth, | |||||
| layer_indices); | |||||
| return (nodes_in_decreasing_depth, layer_indices); | |||||
| } | |||||
| void BuildMapHelper(Tensor tensor, | |||||
| List<Node> finished_nodes, | |||||
| List<Node> nodes_in_progress, | |||||
| List<Node> nodes_in_decreasing_depth, | |||||
| Dictionary<Layer, int> layer_indices) | |||||
| { | |||||
| var (layer, node_index, _) = tensor.KerasHistory; | |||||
| var node = layer.InboundNodes[node_index]; | |||||
| // Don't repeat work for shared subgraphs | |||||
| if (finished_nodes.Contains(node)) | |||||
| return; | |||||
| // Prevent cycles. | |||||
| if (nodes_in_progress.Contains(node)) | |||||
| throw new ValueError($"The tensor {tensor.name} at layer {layer.Name} is part of a cycle."); | |||||
| // Store the traversal order for layer sorting. | |||||
| if (!layer_indices.ContainsKey(layer)) | |||||
| layer_indices[layer] = layer_indices.Count; | |||||
| // Propagate to all previous tensors connected to this node. | |||||
| nodes_in_progress.Add(node); | |||||
| foreach (var k_tensor in node.KerasInputs) | |||||
| BuildMapHelper(k_tensor, | |||||
| finished_nodes, | |||||
| nodes_in_progress, | |||||
| nodes_in_decreasing_depth, | |||||
| layer_indices); | |||||
| finished_nodes.Add(node); | |||||
| nodes_in_progress.Remove(node); | |||||
| nodes_in_decreasing_depth.Insert(nodes_in_decreasing_depth.Count, node); | |||||
| } | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | |||||
| return run_internal_graph(inputs, is_training); | |||||
| } | |||||
| Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null) | |||||
| { | |||||
| if (mask != null) | |||||
| { | |||||
| Tensor[] masks = new Tensor[inputs.Count()]; | |||||
| foreach (var (i, input_t) in enumerate(inputs)) | |||||
| input_t.KerasMask = masks[i]; | |||||
| } | |||||
| var tensor_dict = new Dictionary<int, Tensor[]>(); | |||||
| foreach (var (x, y) in zip(this.inputs, inputs)) | |||||
| { | |||||
| var y1 = conform_to_reference_input(y, x); | |||||
| var x_id = x.GetHashCode(); | |||||
| tensor_dict[x_id] = Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y1).ToArray(); | |||||
| } | |||||
| var depth_keys = NodesByDepth.Keys.Reverse().ToArray(); | |||||
| foreach(var depth in depth_keys) | |||||
| { | |||||
| var nodes = NodesByDepth[depth]; | |||||
| foreach(var node in nodes) | |||||
| { | |||||
| // Input tensors already exist. | |||||
| if (node.IsInput) | |||||
| continue; | |||||
| var layer_inputs = new Tensors(tensor_dict[node.FlatInputIds[0]]); | |||||
| tensor_dict[node.FlatInputIds[0]] = new Tensor[0]; | |||||
| var outputs = node.Layer.Apply(layer_inputs, is_training: training); | |||||
| // Update tensor_dict. | |||||
| foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs)) | |||||
| tensor_dict[x_id] = Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y).ToArray(); | |||||
| } | |||||
| } | |||||
| foreach(var x in outputs) | |||||
| { | |||||
| } | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| Tensor conform_to_reference_input(Tensor tensor, Tensor ref_input) | |||||
| { | |||||
| return tensor; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -9,10 +9,10 @@ namespace Tensorflow.Keras.Engine | |||||
| /// </summary> | /// </summary> | ||||
| public class KerasHistory | public class KerasHistory | ||||
| { | { | ||||
| public Layer layer; | |||||
| Layer layer; | |||||
| int node_index; | int node_index; | ||||
| int tensor_index; | int tensor_index; | ||||
| public Tensor tensor; | |||||
| Tensor tensor; | |||||
| public KerasHistory(Layer layer, int node_index, int tensor_index, Tensor tensor) | public KerasHistory(Layer layer, int node_index, int tensor_index, Tensor tensor) | ||||
| { | { | ||||
| @@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine | |||||
| if (!built) | if (!built) | ||||
| MaybeBuild(inputs); | MaybeBuild(inputs); | ||||
| outputs = call_fn(inputs, state: state, is_training: is_training); | |||||
| outputs = CallFn(inputs, state: state, is_training: is_training); | |||||
| outputs = _set_connectivity_metadata_(inputs, outputs); | outputs = _set_connectivity_metadata_(inputs, outputs); | ||||
| _handle_activity_regularization(inputs, outputs); | _handle_activity_regularization(inputs, outputs); | ||||
| @@ -42,7 +42,7 @@ namespace Tensorflow.Keras.Engine | |||||
| if (!dynamic) | if (!dynamic) | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| outputs = call_fn(inputs); | |||||
| outputs = CallFn(inputs); | |||||
| outputs = _set_connectivity_metadata_(inputs, outputs); | outputs = _set_connectivity_metadata_(inputs, outputs); | ||||
| _handle_activity_regularization(inputs, outputs); | _handle_activity_regularization(inputs, outputs); | ||||
| @@ -162,7 +162,7 @@ namespace Tensorflow.Keras.Engine | |||||
| /// <param name="state"></param> | /// <param name="state"></param> | ||||
| /// <param name="is_training"></param> | /// <param name="is_training"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| protected virtual Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| protected virtual Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| @@ -39,20 +39,42 @@ namespace Tensorflow.Keras.Engine | |||||
| public Tensors Outputs => args.Outputs; | public Tensors Outputs => args.Outputs; | ||||
| public TensorShape[] input_shapes; | public TensorShape[] input_shapes; | ||||
| public TensorShape[] output_shapes; | public TensorShape[] output_shapes; | ||||
| List<Tensor> kerasInputs = new List<Tensor>(); | |||||
| public List<Tensor> KerasInputs = new List<Tensor>(); | |||||
| public Layer Layer { get; set; } | |||||
| public bool IsInput => args.InputTensors == null; | |||||
| public int[] FlatInputIds { get; set; } | |||||
| public int[] FlatOutputIds { get; set; } | |||||
| public Node[] ParentNodes | |||||
| { | |||||
| get | |||||
| { | |||||
| var node_deps = new List<Node>(); | |||||
| foreach(var kt in KerasInputs) | |||||
| { | |||||
| var (layer, node_index, _) = kt.KerasHistory; | |||||
| if (layer != null) | |||||
| node_deps.append(layer.InboundNodes[node_index]); | |||||
| } | |||||
| return node_deps.ToArray(); | |||||
| } | |||||
| } | |||||
| public Node(Layer layer, NodeArgs args) | public Node(Layer layer, NodeArgs args) | ||||
| { | { | ||||
| this.args = args; | this.args = args; | ||||
| this.Layer = layer; | |||||
| if (args.InputTensors != null) | if (args.InputTensors != null) | ||||
| kerasInputs.AddRange(args.InputTensors); | |||||
| KerasInputs.AddRange(args.InputTensors); | |||||
| // Wire up Node to Layers. | // Wire up Node to Layers. | ||||
| layer.InboundNodes.Add(this); | layer.InboundNodes.Add(this); | ||||
| foreach (var kt in kerasInputs) | |||||
| foreach (var kt in KerasInputs) | |||||
| { | { | ||||
| var inbound_layer = kt.KerasHistory.layer; | |||||
| if (kt.KerasHistory == null) | |||||
| continue; | |||||
| var (inbound_layer, _, _) = kt.KerasHistory; | |||||
| if (inbound_layer != null) | if (inbound_layer != null) | ||||
| inbound_layer.OutboundNodes.Add(this); | inbound_layer.OutboundNodes.Add(this); | ||||
| } | } | ||||
| @@ -61,6 +83,10 @@ namespace Tensorflow.Keras.Engine | |||||
| var node_index = layer.InboundNodes.Count - 1; | var node_index = layer.InboundNodes.Count - 1; | ||||
| foreach (var (i, tensor) in enumerate(Outputs)) | foreach (var (i, tensor) in enumerate(Outputs)) | ||||
| tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor); | tensor.KerasHistory = new KerasHistory(layer, node_index, i, tensor); | ||||
| // Cached for performance. | |||||
| FlatInputIds = KerasInputs.Select(x => x.GetHashCode()).ToArray(); | |||||
| FlatOutputIds = Outputs.Select(x => x.GetHashCode()).ToArray(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -23,9 +23,9 @@ namespace Tensorflow.Keras.Engine | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| return base.call_fn(inputs, state, is_training); | |||||
| return base.CallFn(inputs, state, is_training); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -119,7 +119,7 @@ namespace Tensorflow.Keras.Layers | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| Tensor outputs = null; | Tensor outputs = null; | ||||
| @@ -98,7 +98,7 @@ namespace Tensorflow.Keras.Layers | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false) | |||||
| { | { | ||||
| var outputs = _convolution_op.Apply(inputs, kernel); | var outputs = _convolution_op.Apply(inputs, kernel); | ||||
| if (use_bias) | if (use_bias) | ||||
| @@ -65,7 +65,7 @@ namespace Tensorflow.Keras.Layers | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool training = false) | |||||
| { | { | ||||
| Tensor outputs = null; | Tensor outputs = null; | ||||
| var rank = inputs.rank; | var rank = inputs.rank; | ||||
| @@ -18,7 +18,7 @@ namespace Tensorflow.Keras.Layers | |||||
| this.args = args; | this.args = args; | ||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| var output = tf_utils.smart_cond(is_training, | var output = tf_utils.smart_cond(is_training, | ||||
| () => tf.nn.dropout(inputs, | () => tf.nn.dropout(inputs, | ||||
| @@ -62,7 +62,7 @@ namespace Tensorflow.Keras.Layers | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| var dtype = inputs.dtype; | var dtype = inputs.dtype; | ||||
| if (dtype != tf.int32 && dtype != tf.int64) | if (dtype != tf.int32 && dtype != tf.int64) | ||||
| @@ -29,9 +29,9 @@ namespace Tensorflow.Keras.Layers | |||||
| .ToArray(); | .ToArray(); | ||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| return base.call_fn(inputs, state: state, is_training: is_training); | |||||
| return base.CallFn(inputs, state: state, is_training: is_training); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -36,7 +36,7 @@ namespace Tensorflow.Keras.Layers | |||||
| input_spec = new InputSpec(ndim: 4); | input_spec = new InputSpec(ndim: 4); | ||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| int[] pool_shape; | int[] pool_shape; | ||||
| int[] strides; | int[] strides; | ||||
| @@ -20,7 +20,7 @@ namespace Tensorflow.Keras.Layers | |||||
| this.args = args; | this.args = args; | ||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| scale = math_ops.cast(args.Scale, args.DType); | scale = math_ops.cast(args.Scale, args.DType); | ||||
| offset = math_ops.cast(args.Offset, args.DType); | offset = math_ops.cast(args.Offset, args.DType); | ||||
| @@ -29,7 +29,7 @@ namespace Tensorflow.Keras.Layers | |||||
| this.input_spec = new InputSpec(ndim: 4); | this.input_spec = new InputSpec(ndim: 4); | ||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| return tf.keras.backend.spatial_2d_padding(inputs, | return tf.keras.backend.spatial_2d_padding(inputs, | ||||
| padding: padding, | padding: padding, | ||||
| @@ -74,7 +74,7 @@ namespace Tensorflow | |||||
| /// <param name="training"></param> | /// <param name="training"></param> | ||||
| /// <param name="state"></param> | /// <param name="state"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| var one = constant_op.constant(1, dtype: dtypes.int32); | var one = constant_op.constant(1, dtype: dtypes.int32); | ||||
| // Parameters of gates are concatenated into one multiply for efficiency. | // Parameters of gates are concatenated into one multiply for efficiency. | ||||
| @@ -67,7 +67,7 @@ namespace Tensorflow | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensors call_fn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| protected override Tensors CallFn(Tensors inputs, Tensor state = null, bool is_training = false) | |||||
| { | { | ||||
| // Most basic RNN: output = new_state = act(W * input + U * state + B). | // Most basic RNN: output = new_state = act(W * input + U * state + B). | ||||
| var concat = array_ops.concat(new Tensor[] { inputs, state }, 1); | var concat = array_ops.concat(new Tensor[] { inputs, state }, 1); | ||||
| @@ -145,6 +145,7 @@ namespace Tensorflow | |||||
| /// Keras History: (Layer, (node_index, tensor_index)) | /// Keras History: (Layer, (node_index, tensor_index)) | ||||
| /// </summary> | /// </summary> | ||||
| public KerasHistory KerasHistory { get; set; } | public KerasHistory KerasHistory { get; set; } | ||||
| public Tensor KerasMask { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Updates the shape of this tensor. | /// Updates the shape of this tensor. | ||||