From 01e88bb8bbbf062af3f8cd85720aa8cf9d06ca3a Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Wed, 1 Mar 2023 15:30:36 +0800 Subject: [PATCH] Revise customized json converters. --- .../Common/CustomizedAxisJsonConverter.cs | 11 +++++++++- .../Common/CustomizedShapeJsonConverter.cs | 22 +++++++++++++++++-- .../Training/Saving/SavedModel/loader.cs | 1 + src/TensorFlowNET.Keras/Engine/Functional.cs | 9 +++++++- .../Saving/KerasObjectLoader.cs | 15 ++++++++----- .../SaveModel/SequentialModelLoad.cs | 2 +- 6 files changed, 50 insertions(+), 10 deletions(-) diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs index dfd8735b..f6087a43 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs @@ -37,7 +37,16 @@ namespace Tensorflow.Keras.Common public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) { - var axis = serializer.Deserialize(reader, typeof(int[])); + int[]? axis; + if(reader.ValueType == typeof(long)) + { + axis = new int[1]; + axis[0] = (int)serializer.Deserialize(reader, typeof(int)); + } + else + { + axis = serializer.Deserialize(reader, typeof(int[])) as int[]; + } if (axis is null) { throw new ValueError("Cannot deserialize 'null' to `Axis`."); diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs index 300cb2f2..c7812eec 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs @@ -51,8 +51,26 @@ namespace Tensorflow.Keras.Common public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) { - var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; - if(dims is null) + long?[] dims; + try + { + dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; + } + catch (JsonSerializationException ex) + { + if (reader.Value.Equals("class_name")) + { + reader.Read(); + reader.Read(); + reader.Read(); + dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; + } + else + { + throw ex; + } + } + if (dims is null) { throw new ValueError("Cannot deserialize 'null' to `Shape`."); } diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index 9595ba11..9e2654a7 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -11,6 +11,7 @@ using pbc = global::Google.Protobuf.Collections; using static Tensorflow.Binding; using System.Runtime.CompilerServices; using Tensorflow.Variables; +using Tensorflow.Functions; namespace Tensorflow { diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 61eae06e..33320101 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -75,7 +75,14 @@ namespace Tensorflow.Keras.Engine this.inputs = inputs; this.outputs = outputs; built = true; - _buildInputShape = inputs.shape; + if(inputs.Length > 0) + { + _buildInputShape = inputs.shape; + } + else + { + _buildInputShape = new Saving.TensorShapeConfig(); + } if (outputs.Any(x => x.KerasHistory == null)) base_layer_utils.create_keras_history(outputs); diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index b378ea64..cf9e4652 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -72,6 +72,10 @@ namespace Tensorflow.Keras.Saving { try { + if (node_metadata.Identifier.Equals("_tf_keras_metric")) + { + continue; + } loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata); } @@ -324,7 +328,9 @@ namespace Tensorflow.Keras.Saving Trackable obj; if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) { - throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); + // TODO(Rinne): implement it. + return (null, null); + //throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); } else { @@ -343,7 +349,7 @@ namespace Tensorflow.Keras.Saving private (Trackable, Action) _revive_custom_object(string identifier, KerasMetaData metadata) { - // TODO: implement it. + // TODO(Rinne): implement it. throw new NotImplementedException(); } @@ -367,15 +373,14 @@ namespace Tensorflow.Keras.Saving } else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER) { - model = model = new Sequential(new SequentialArgs + model = new Sequential(new SequentialArgs { Name = class_name }); } else { - // TODO: implement it. - throw new NotImplementedException("Not implemented"); + model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject()); } // Record this model and its layers. This will later be used to reconstruct diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index d43b1358..57a69249 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -21,7 +21,7 @@ public class SequentialModelLoad [TestMethod] public void SimpleModelFromSequential() { - var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/tf.net.simple.sequential"); + var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/model.pb"); Debug.Assert(model is Model); var m = model as Model;