Browse Source

Merge pull request #1192 from Jucko13/master

fix: Implemented support for loading models with Concatenate layers
tags/v0.150.0-BERT-Model
Haiping GitHub 2 years ago
parent
commit
9e3654bf9c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 97 additions and 20 deletions
  1. +4
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs
  2. +9
    -0
      src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs
  3. +18
    -12
      src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs
  4. +1
    -0
      src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
  5. +12
    -1
      src/TensorFlowNET.Keras/Utils/generic_utils.cs
  6. +10
    -5
      test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs
  7. +43
    -0
      test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs

+ 4
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs View File

@@ -1,13 +1,15 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
// TODO: complete the implementation // TODO: complete the implementation
public class MergeArgs : LayerArgs
public class MergeArgs : AutoSerializeLayerArgs
{ {
public Tensors Inputs { get; set; } public Tensors Inputs { get; set; }
[JsonProperty("axis")]
public int Axis { get; set; } public int Axis { get; set; }
} }
} }

+ 9
- 0
src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs View File

@@ -30,6 +30,15 @@ namespace Tensorflow.NumPy
[AutoNumPy] [AutoNumPy]
public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays)); public static NDArray stack(params NDArray[] arrays) => new NDArray(array_ops.stack(arrays));


[AutoNumPy]
public static NDArray stack(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.stack(arrays, axis));
[AutoNumPy]
public static NDArray stack((NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2 }, axis));

[AutoNumPy]
public static NDArray stack((NDArray, NDArray, NDArray) tuple, int axis = 0) => new NDArray(array_ops.stack(new[] { tuple.Item1, tuple.Item2, tuple.Item3 }, axis));

[AutoNumPy] [AutoNumPy]
public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination)); public static NDArray moveaxis(NDArray array, Axis source, Axis destination) => new NDArray(array_ops.moveaxis(array, source, destination));
} }


+ 18
- 12
src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs View File

@@ -30,7 +30,7 @@ namespace Tensorflow.Keras.Engine
created_layers = created_layers ?? new Dictionary<string, ILayer>(); created_layers = created_layers ?? new Dictionary<string, ILayer>();
var node_index_map = new Dictionary<(string, int), int>(); var node_index_map = new Dictionary<(string, int), int>();
var node_count_by_layer = new Dictionary<ILayer, int>(); var node_count_by_layer = new Dictionary<ILayer, int>();
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();
var unprocessed_nodes = new Dictionary<ILayer, List<NodeConfig>>();
// First, we create all layers and enqueue nodes to be processed // First, we create all layers and enqueue nodes to be processed
foreach (var layer_data in config.Layers) foreach (var layer_data in config.Layers)
process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer); process_layer(created_layers, layer_data, unprocessed_nodes, node_count_by_layer);
@@ -79,7 +79,7 @@ namespace Tensorflow.Keras.Engine


static void process_layer(Dictionary<string, ILayer> created_layers, static void process_layer(Dictionary<string, ILayer> created_layers,
LayerConfig layer_data, LayerConfig layer_data,
Dictionary<ILayer, NodeConfig> unprocessed_nodes,
Dictionary<ILayer, List<NodeConfig>> unprocessed_nodes,
Dictionary<ILayer, int> node_count_by_layer) Dictionary<ILayer, int> node_count_by_layer)
{ {
ILayer layer = null; ILayer layer = null;
@@ -92,32 +92,38 @@ namespace Tensorflow.Keras.Engine


created_layers[layer_name] = layer; created_layers[layer_name] = layer;
} }
node_count_by_layer[layer] = _should_skip_first_node(layer) ? 1 : 0;
node_count_by_layer[layer] = layer_data.InboundNodes.Count - (_should_skip_first_node(layer) ? 1 : 0);


var inbound_nodes_data = layer_data.InboundNodes; var inbound_nodes_data = layer_data.InboundNodes;
foreach (var node_data in inbound_nodes_data) foreach (var node_data in inbound_nodes_data)
{ {
if (!unprocessed_nodes.ContainsKey(layer)) if (!unprocessed_nodes.ContainsKey(layer))
unprocessed_nodes[layer] = node_data;
unprocessed_nodes[layer] = new List<NodeConfig>() { node_data };
else else
unprocessed_nodes.Add(layer, node_data);
unprocessed_nodes[layer].Add(node_data);
} }
} }


static void process_node(ILayer layer, static void process_node(ILayer layer,
NodeConfig node_data,
List<NodeConfig> nodes_data,
Dictionary<string, ILayer> created_layers, Dictionary<string, ILayer> created_layers,
Dictionary<ILayer, int> node_count_by_layer, Dictionary<ILayer, int> node_count_by_layer,
Dictionary<(string, int), int> node_index_map) Dictionary<(string, int), int> node_index_map)
{ {

var input_tensors = new List<Tensor>(); var input_tensors = new List<Tensor>();
var inbound_layer_name = node_data.Name;
var inbound_node_index = node_data.NodeIndex;
var inbound_tensor_index = node_data.TensorIndex;


var inbound_layer = created_layers[inbound_layer_name];
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
for (int i = 0; i < nodes_data.Count; i++)
{
var node_data = nodes_data[i];
var inbound_layer_name = node_data.Name;
var inbound_node_index = node_data.NodeIndex;
var inbound_tensor_index = node_data.TensorIndex;

var inbound_layer = created_layers[inbound_layer_name];
var inbound_node = inbound_layer.InboundNodes[inbound_node_index];
input_tensors.Add(inbound_node.Outputs[inbound_node_index]);
}


var output_tensors = layer.Apply(input_tensors); var output_tensors = layer.Apply(input_tensors);




+ 1
- 0
src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs View File

@@ -39,6 +39,7 @@ namespace Tensorflow.Keras.Layers
shape_set.Add(shape); shape_set.Add(shape);
}*/ }*/
_buildInputShape = input_shape; _buildInputShape = input_shape;
built = true;
} }


protected override Tensors _merge_function(Tensors inputs) protected override Tensors _merge_function(Tensors inputs)


+ 12
- 1
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -112,12 +112,23 @@ namespace Tensorflow.Keras.Utils
foreach (var token in layersToken) foreach (var token in layersToken)
{ {
var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]); var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]);

List<NodeConfig> nodeConfig = null; //python tensorflow sometimes exports inbound nodes in an extra nested array
if (token["inbound_nodes"].Count() > 0 && token["inbound_nodes"][0].Count() > 0 && token["inbound_nodes"][0][0].Count() > 0)
{
nodeConfig = token["inbound_nodes"].ToObject<List<List<NodeConfig>>>().FirstOrDefault() ?? new List<NodeConfig>();
}
else
{
nodeConfig = token["inbound_nodes"].ToObject<List<NodeConfig>>();
}

config.Layers.Add(new LayerConfig() config.Layers.Add(new LayerConfig()
{ {
Config = args, Config = args,
Name = token["name"].ToObject<string>(), Name = token["name"].ToObject<string>(),
ClassName = token["class_name"].ToObject<string>(), ClassName = token["class_name"].ToObject<string>(),
InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>()
InboundNodes = nodeConfig,
}); });
} }
config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>(); config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>();


+ 10
- 5
test/TensorFlowNET.Keras.UnitTest/Layers/Layers.Merging.Test.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using System.Collections.Generic;
using Tensorflow.NumPy; using Tensorflow.NumPy;
using static Tensorflow.KerasApi; using static Tensorflow.KerasApi;


@@ -8,12 +9,16 @@ namespace Tensorflow.Keras.UnitTest.Layers
public class LayersMergingTest : EagerModeTestBase public class LayersMergingTest : EagerModeTestBase
{ {
[TestMethod] [TestMethod]
public void Concatenate()
[DataRow(1, 4, 1, 5)]
[DataRow(2, 2, 2, 5)]
[DataRow(3, 2, 1, 10)]
public void Concatenate(int axis, int shapeA, int shapeB, int shapeC)
{ {
var x = np.arange(20).reshape((2, 2, 5));
var y = np.arange(20, 30).reshape((2, 1, 5));
var z = keras.layers.Concatenate(axis: 1).Apply(new Tensors(x, y));
Assert.AreEqual((2, 3, 5), z.shape);
var x = np.arange(10).reshape((1, 2, 1, 5));
var y = np.arange(10, 20).reshape((1, 2, 1, 5));
var z = keras.layers.Concatenate(axis: axis).Apply(new Tensors(x, y));
Assert.AreEqual((1, shapeA, shapeB, shapeC), z.shape);
} }

} }
} }

+ 43
- 0
test/TensorFlowNET.Keras.UnitTest/Model/ModelLoadTest.cs View File

@@ -1,10 +1,13 @@
using Microsoft.VisualStudio.TestPlatform.Utilities; using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Newtonsoft.Json.Linq;
using System.Linq; using System.Linq;
using System.Xml.Linq;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.UnitTest.Helpers; using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow.NumPy; using Tensorflow.NumPy;
using static HDF.PInvoke.H5Z;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using static Tensorflow.KerasApi; using static Tensorflow.KerasApi;


@@ -124,4 +127,44 @@ public class ModelLoadTest
var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model; var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Tensorflow.Keras.Engine.Model;
model.summary(); model.summary();
} }



[TestMethod]
public void CreateConcatenateModelSaveAndLoad()
{
// a small demo model that is just here to see if the axis value for the concatenate method is saved and loaded.
var input_layer = tf.keras.layers.Input((8, 8, 5));

var conv1 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_1"*/).Apply(input_layer);
conv1.Name = "conv1";

var conv2 = tf.keras.layers.Conv2D(2, kernel_size: 3, activation: "relu", padding: "same"/*, data_format: "_conv_2"*/).Apply(input_layer);
conv2.Name = "conv2";

var concat1 = tf.keras.layers.Concatenate(axis: 3).Apply((conv1, conv2));
concat1.Name = "concat1";

var model = tf.keras.Model(input_layer, concat1);
model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.CategoricalCrossentropy());

model.save(@"Assets/concat_axis3_model");

var tensorInput = np.arange(320).reshape((1, 8, 8, 5)).astype(TF_DataType.TF_FLOAT);

var tensors1 = model.predict(tensorInput);

Assert.AreEqual((1, 8, 8, 4), tensors1.shape);

model = null;
keras.backend.clear_session();

var model2 = tf.keras.models.load_model(@"Assets/concat_axis3_model");

var tensors2 = model2.predict(tensorInput);

Assert.AreEqual(tensors1.shape, tensors2.shape);
}

} }

Loading…
Cancel
Save