| @@ -37,7 +37,16 @@ namespace Tensorflow.Keras.Common | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | ||||
| { | { | ||||
| var axis = serializer.Deserialize(reader, typeof(int[])); | |||||
| int[]? axis; | |||||
| if(reader.ValueType == typeof(long)) | |||||
| { | |||||
| axis = new int[1]; | |||||
| axis[0] = (int)serializer.Deserialize(reader, typeof(int)); | |||||
| } | |||||
| else | |||||
| { | |||||
| axis = serializer.Deserialize(reader, typeof(int[])) as int[]; | |||||
| } | |||||
| if (axis is null) | if (axis is null) | ||||
| { | { | ||||
| throw new ValueError("Cannot deserialize 'null' to `Axis`."); | throw new ValueError("Cannot deserialize 'null' to `Axis`."); | ||||
| @@ -51,8 +51,26 @@ namespace Tensorflow.Keras.Common | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | ||||
| { | { | ||||
| var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
| if(dims is null) | |||||
| long?[] dims; | |||||
| try | |||||
| { | |||||
| dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
| } | |||||
| catch (JsonSerializationException ex) | |||||
| { | |||||
| if (reader.Value.Equals("class_name")) | |||||
| { | |||||
| reader.Read(); | |||||
| reader.Read(); | |||||
| reader.Read(); | |||||
| dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw ex; | |||||
| } | |||||
| } | |||||
| if (dims is null) | |||||
| { | { | ||||
| throw new ValueError("Cannot deserialize 'null' to `Shape`."); | throw new ValueError("Cannot deserialize 'null' to `Shape`."); | ||||
| } | } | ||||
| @@ -11,6 +11,7 @@ using pbc = global::Google.Protobuf.Collections; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using Tensorflow.Variables; | using Tensorflow.Variables; | ||||
| using Tensorflow.Functions; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -75,7 +75,14 @@ namespace Tensorflow.Keras.Engine | |||||
| this.inputs = inputs; | this.inputs = inputs; | ||||
| this.outputs = outputs; | this.outputs = outputs; | ||||
| built = true; | built = true; | ||||
| _buildInputShape = inputs.shape; | |||||
| if(inputs.Length > 0) | |||||
| { | |||||
| _buildInputShape = inputs.shape; | |||||
| } | |||||
| else | |||||
| { | |||||
| _buildInputShape = new Saving.TensorShapeConfig(); | |||||
| } | |||||
| if (outputs.Any(x => x.KerasHistory == null)) | if (outputs.Any(x => x.KerasHistory == null)) | ||||
| base_layer_utils.create_keras_history(outputs); | base_layer_utils.create_keras_history(outputs); | ||||
| @@ -72,6 +72,10 @@ namespace Tensorflow.Keras.Saving | |||||
| { | { | ||||
| try | try | ||||
| { | { | ||||
| if (node_metadata.Identifier.Equals("_tf_keras_metric")) | |||||
| { | |||||
| continue; | |||||
| } | |||||
| loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, | loaded_nodes[node_metadata.NodeId] = _load_layer(node_metadata.NodeId, node_metadata.Identifier, | ||||
| node_metadata.Metadata); | node_metadata.Metadata); | ||||
| } | } | ||||
| @@ -324,7 +328,9 @@ namespace Tensorflow.Keras.Saving | |||||
| Trackable obj; | Trackable obj; | ||||
| if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) | if(identifier == Keras.Saving.SavedModel.Constants.METRIC_IDENTIFIER) | ||||
| { | { | ||||
| throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
| // TODO(Rinne): implement it. | |||||
| return (null, null); | |||||
| //throw new NotImplementedException("Not implemented, please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -343,7 +349,7 @@ namespace Tensorflow.Keras.Saving | |||||
| private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | private (Trackable, Action<object, object, object>) _revive_custom_object(string identifier, KerasMetaData metadata) | ||||
| { | { | ||||
| // TODO: implement it. | |||||
| // TODO(Rinne): implement it. | |||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| @@ -367,15 +373,14 @@ namespace Tensorflow.Keras.Saving | |||||
| } | } | ||||
| else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER) | else if(identifier == Keras.Saving.SavedModel.Constants.SEQUENTIAL_IDENTIFIER) | ||||
| { | { | ||||
| model = model = new Sequential(new SequentialArgs | |||||
| model = new Sequential(new SequentialArgs | |||||
| { | { | ||||
| Name = class_name | Name = class_name | ||||
| }); | }); | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| // TODO: implement it. | |||||
| throw new NotImplementedException("Not implemented"); | |||||
| model = new Functional(new Tensors(), new Tensors(), config["name"].ToObject<string>()); | |||||
| } | } | ||||
| // Record this model and its layers. This will later be used to reconstruct | // Record this model and its layers. This will later be used to reconstruct | ||||
| @@ -21,7 +21,7 @@ public class SequentialModelLoad | |||||
| [TestMethod] | [TestMethod] | ||||
| public void SimpleModelFromSequential() | public void SimpleModelFromSequential() | ||||
| { | { | ||||
| var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/tf.net.simple.sequential"); | |||||
| var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/model.pb"); | |||||
| Debug.Assert(model is Model); | Debug.Assert(model is Model); | ||||
| var m = model as Model; | var m = model as Model; | ||||