fix: Implemented support for loading models with Concatenate layerstags/v0.150.0-BERT-Model
| @@ -1,13 +1,15 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| // TODO: complete the implementation | // TODO: complete the implementation | ||||
| public class MergeArgs : LayerArgs | |||||
| public class MergeArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| public Tensors Inputs { get; set; } | public Tensors Inputs { get; set; } | ||||
| [JsonProperty("axis")] | |||||
| public int Axis { get; set; } | public int Axis { get; set; } | ||||
| } | } | ||||
| } | } | ||||
| @@ -30,6 +30,15 @@ namespace Tensorflow.NumPy | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays)); | public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays)); | ||||
| [AutoNumPy] | |||||
| public static NDArray stack(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.stack(arrays, axis)); | |||||
| [AutoNumPy] | |||||
| public static NDArray stack((NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2 }, axis)); | |||||
| [AutoNumPy] | |||||
| public static NDArray stack((NDArray, NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2, tuple.Item3 }, axis)); | |||||
| [AutoNumPy] | [AutoNumPy] | ||||
| public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination)); | public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination)); | ||||
| } | } | ||||
| @@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine | |||||
| created_layers = created_layers ?? new Dictionary<string, ILayer>(); | created_layers = created_layers ?? new Dictionary<string, ILayer>(); | ||||
| var node_index_map = new Dictionary<(string, int), int>(); | var node_index_map = new Dictionary<(string, int), int>(); | ||||
| var node_count_by_layer = new Dictionary<ILayer, int>(); | var node_count_by_layer = new Dictionary<ILayer, int>(); | ||||
| var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>(); | |||||
| var unprocessed_nodes = new Dictionary<ILayer, List<NodeConfig>>(); | |||||
| // First, we create all layers and enqueue nodes to be processed | // First, we create all layers and enqueue nodes to be processed | ||||
| foreach (var layer_data in config.Layers) | foreach (var layer_data in config.Layers) | ||||
| process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer); | process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer); | ||||
| @@ -79,7 +79,7 @@ namespace Tensorflow.Keras.Engine | |||||
| static void process_layer(Dictionary<string, ILayer> created_layers, | static void process_layer(Dictionary<string, ILayer> created_layers, | ||||
| LayerConfig layer_data, | LayerConfig layer_data, | ||||
| Dictionary<ILayer, NodeConfig> unprocessed_nodes, | |||||
| Dictionary<ILayer, List<NodeConfig>> unprocessed_nodes, | |||||
| Dictionary<ILayer, int> node_count_by_layer) | Dictionary<ILayer, int> node_count_by_layer) | ||||
| { | { | ||||
| ILayer layer = null; | ILayer layer = null; | ||||
| @@ -92,32 +92,38 @@ namespace Tensorflow.Keras.Engine | |||||
| created_layers[layer_name] = layer; | created_layers[layer_name] = layer; | ||||
| } | } | ||||
| node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0; | |||||
| node_count_by_layer[layer] = layer_data.InboundNodes.Count - (_should_skip_first_node(layer) ? 1 : 0); | |||||
| var inbound_nodes_data = layer_data.InboundNodes; | var inbound_nodes_data = layer_data.InboundNodes; | ||||
| foreach (var node_data in inbound_nodes_data) | foreach (var node_data in inbound_nodes_data) | ||||
| { | { | ||||
| if (!unprocessed_nodes.ContainsKey(layer)) | if (!unprocessed_nodes.ContainsKey(layer)) | ||||
| unprocessed_nodes[layer] = node_data; | |||||
| unprocessed_nodes[layer] = new List<NodeConfig>() { node_data }; | |||||
| else | else | ||||
| unprocessed_nodes.Add(layer, node_data); | |||||
| unprocessed_nodes[layer].Add(node_data); | |||||
| } | } | ||||
| } | } | ||||
| static void process_node(ILayer layer, | static void process_node(ILayer layer, | ||||
| NodeConfig node_data, | |||||
| List<NodeConfig> nodes_data, | |||||
| Dictionary<string, ILayer> created_layers, | Dictionary<string, ILayer> created_layers, | ||||
| Dictionary<ILayer, int> node_count_by_layer, | Dictionary<ILayer, int> node_count_by_layer, | ||||
| Dictionary<(string, int), int> node_index_map) | Dictionary<(string, int), int> node_index_map) | ||||
| { | { | ||||
| var input_tensors = new List<Tensor>(); | var input_tensors = new List<Tensor>(); | ||||
| var inbound_layer_name = node_data.Name; | |||||
| var inbound_node_index = node_data.NodeIndex; | |||||
| var inbound_tensor_index = node_data.TensorIndex; | |||||
| var inbound_layer = created_layers[inbound_layer_name]; | |||||
| var inbound_node = inbound_layer.InboundNodes[inbound_node_index]; | |||||
| input_tensors.Add(inbound_node.Outputs[inbound_node_index]); | |||||
| for (int i = 0; i < nodes_data.Count; i++) | |||||
| { | |||||
| var node_data = nodes_data[i]; | |||||
| var inbound_layer_name = node_data.Name; | |||||
| var inbound_node_index = node_data.NodeIndex; | |||||
| var inbound_tensor_index = node_data.TensorIndex; | |||||
| var inbound_layer = created_layers[inbound_layer_name]; | |||||
| var inbound_node = inbound_layer.InboundNodes[inbound_node_index]; | |||||
| input_tensors.Add(inbound_node.Outputs[inbound_node_index]); | |||||
| } | |||||
| var output_tensors = layer.Apply(input_tensors); | var output_tensors = layer.Apply(input_tensors); | ||||
| @@ -39,6 +39,7 @@ namespace Tensorflow.Keras.Layers | |||||
| shape_set.Add(shape); | shape_set.Add(shape); | ||||
| }*/ | }*/ | ||||
| _buildInputShape = input_shape; | _buildInputShape = input_shape; | ||||
| built = true; | |||||
| } | } | ||||
| protected override Tensors _merge_function(Tensors inputs) | protected override Tensors _merge_function(Tensors inputs) | ||||
| @@ -112,12 +112,23 @@ namespace Tensorflow.Keras.Utils | |||||
| foreach (var token in layersToken) | foreach (var token in layersToken) | ||||
| { | { | ||||
| var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]); | var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]); | ||||
| List<NodeConfig> nodeConfig = null; //python tensorflow sometimes exports inbound nodes in an extra nested array | |||||
| if (token["inbound_nodes"].Count() > 0 && token["inbound_nodes"][0].Count() > 0 && token["inbound_nodes"][0][0].Count() > 0) | |||||
| { | |||||
| nodeConfig = token["inbound_nodes"].ToObject<List<List<NodeConfig>>>().FirstOrDefault() ?? new List<NodeConfig>(); | |||||
| } | |||||
| else | |||||
| { | |||||
| nodeConfig = token["inbound_nodes"].ToObject<List<NodeConfig>>(); | |||||
| } | |||||
| config.Layers.Add(new LayerConfig() | config.Layers.Add(new LayerConfig() | ||||
| { | { | ||||
| Config = args, | Config = args, | ||||
| Name = token["name"].ToObject<string>(), | Name = token["name"].ToObject<string>(), | ||||
| ClassName = token["class_name"].ToObject<string>(), | ClassName = token["class_name"].ToObject<string>(), | ||||
| InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>() | |||||
| InboundNodes = nodeConfig, | |||||
| }); | }); | ||||
| } | } | ||||
| config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>(); | config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>(); | ||||
| @@ -1,4 +1,5 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System.Collections.Generic; | |||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
| @@ -8,12 +9,16 @@ namespace Tensorflow.Keras.UnitTest.Layers | |||||
| public class LayersMergingTest : EagerModeTestBase | public class LayersMergingTest : EagerModeTestBase | ||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void Concatenate() | |||||
| [DataRow(1, 4, 1, 5)] | |||||
| [DataRow(2, 2, 2, 5)] | |||||
| [DataRow(3, 2, 1, 10)] | |||||
| public void Concatenate(int axis, int shapeA, int shapeB, int shapeC) | |||||
| { | { | ||||
| var x = np.arange(20).reshape((2, 2, 5)); | |||||
| var y = np.arange(20, 30).reshape((2, 1, 5)); | |||||
| var z = keras.layers.Concatenate(axis: 1).Apply(new Tensors(x, y)); | |||||
| Assert.AreEqual((2, 3, 5), z.shape); | |||||
| var x = np.arange(10).reshape((1, 2, 1, 5)); | |||||
| var y = np.arange(10, 20).reshape((1, 2, 1, 5)); | |||||
| var z = keras.layers.Concatenate(axis: axis).Apply(new Tensors(x, y)); | |||||
| Assert.AreEqual((1, shapeA, shapeB, shapeC), z.shape); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,13 @@ | |||||
| using Microsoft.VisualStudio.TestPlatform.Utilities; | using Microsoft.VisualStudio.TestPlatform.Utilities; | ||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Newtonsoft.Json.Linq; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Xml.Linq; | |||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
| using Tensorflow.Keras.UnitTest.Helpers; | using Tensorflow.Keras.UnitTest.Helpers; | ||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| using static HDF.PInvoke.H5Z; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
| @@ -124,4 +127,44 @@ public class ModelLoadTest | |||||
| var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model; | var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model; | ||||
| model.summary(); | model.summary(); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void CreateConcatenateModelSaveAndLoad() | |||||
| { | |||||
| // a small demo model that is just here to see if the axis value for the concatenate method is saved and loaded. | |||||
| var input_layer = tf.keras.layers.Input((8, 8, 5)); | |||||
| var conv1 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_1"*/).Apply(input_layer); | |||||
| conv1.Name = "conv1"; | |||||
| var conv2 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_2"*/).Apply(input_layer); | |||||
| conv2.Name = "conv2"; | |||||
| var concat1 = tf.keras.layers.Concatenate(axis: 3).Apply((conv1, conv2)); | |||||
| concat1.Name = "concat1"; | |||||
| var model = tf.keras.Model(input_layer, concat1); | |||||
| model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy()); | |||||
| model.save(@"Assets/concat_axis3_model"); | |||||
| var tensorInput = np.arange(320).reshape((1, 8, 8, 5)).astype(TF_DataType.TF_FLOAT); | |||||
| var tensors1 = model.predict(tensorInput); | |||||
| Assert.AreEqual((1, 8, 8, 4), tensors1.shape); | |||||
| model = null; | |||||
| keras.backend.clear_session(); | |||||
| var model2 = tf.keras.models.load_model(@"Assets/concat_axis3_model"); | |||||
| var tensors2 = model2.predict(tensorInput); | |||||
| Assert.AreEqual(tensors1.shape, tensors2.shape); | |||||
| } | |||||
| } | } | ||||