diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs index 3f7b1836..c6896ad7 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs @@ -62,14 +62,17 @@ namespace Tensorflow.Checkpoint return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name); } - public unsafe Tensor GetTensor(string name) + public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) { Status status = new Status(); var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle); status.Check(true); var shape = GetVariableShape(name); - var dtype = GetVariableDataType(name); - return new Tensor(c_api.TF_TensorData(tensor), shape, dtype); + if(dtype == TF_DataType.DtInvalid) + { + dtype = GetVariableDataType(name); + } + return new Tensor(tensor); } private void ReadAllShapeAndType() diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index d5cf2ae4..1934ffd5 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -227,7 +227,7 @@ public class TrackableSaver { dtype_map = reader.VariableToDataTypeMap; } - Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY); + Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING); Dictionary file_prefix_feed_dict; Tensor file_prefix_tensor; @@ -249,7 +249,14 @@ public class TrackableSaver file_prefix_feed_dict = null; } TrackableObjectGraph object_graph_proto = new(); - object_graph_proto.MergeFrom(object_graph_string.BufferToArray()); + if(object_graph_string.ndim > 0) + { + object_graph_proto.MergeFrom(object_graph_string.BufferToArray()); + } + else + { + object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]); + } CheckpointRestoreCoordinator 
checkpoint = new CheckpointRestoreCoordinator( object_graph_proto: object_graph_proto, save_path: save_path, diff --git a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs index bac9cedb..a6720a5f 100644 --- a/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs +++ b/src/TensorFlowNET.Core/Functions/ConcreteFunction.cs @@ -13,8 +13,8 @@ namespace Tensorflow.Functions /// public class ConcreteFunction: Trackable { - FuncGraph func_graph; - ForwardBackwardCall forward_backward; + internal FuncGraph func_graph; + internal ForwardBackwardCall forward_backward; public Tensor[] Inputs => func_graph.Inputs; public Tensor[] CapturedInputs => func_graph.external_captures; @@ -23,6 +23,8 @@ namespace Tensorflow.Functions public Tensor[] Outputs; public Type ReturnType; public TensorSpec[] OutputStructure; + public IEnumerable ArgKeywords { get; set; } + public long NumPositionArgs { get; set; } public ConcreteFunction(string name) { @@ -163,6 +165,15 @@ namespace Tensorflow.Functions return flat_outputs; } + public void AddTograph(Graph? g = null) + { + if(!tf.Context.executing_eagerly() && g is null) + { + g = ops.get_default_graph(); + } + // TODO(Rinne); complete it with `_delayed_rewrite_functions`. 
+ } + ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) { var functions = new FirstOrderTapeGradientFunctions(func_graph, false); diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs new file mode 100644 index 00000000..e9086ae9 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs @@ -0,0 +1,39 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedDTypeJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(TF_DataType); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + // Write the underlying enum integer directly: JToken.FromObject(value) would re-enter this converter through the [JsonConverter] attribute on TF_DataType and recurse without bound. + writer.WriteValue(Convert.ToInt32(value)); + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? 
existingValue, JsonSerializer serializer) + { + if (reader.ValueType == typeof(string)) + { + var str = (string)serializer.Deserialize(reader, typeof(string)); + return dtypes.tf_dtype_from_name(str); + } + else + { + return (TF_DataType)(int)serializer.Deserialize(reader, typeof(int)); + } + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs index 1ad19fc8..cfd8ee8f 100644 --- a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs @@ -46,7 +46,16 @@ namespace Tensorflow.Keras.Common { throw new ValueError("Cannot deserialize 'null' to `Shape`."); } - if(values.Length != 3) + if(values.Length == 1) + { + var array = values[0] as JArray; + if(array is null) + { + throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); + } + values = array.ToObject(); + } + if (values.Length < 3) { throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); } @@ -54,19 +63,37 @@ namespace Tensorflow.Keras.Common { throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); } - if (values[1] is not int) + int nodeIndex; + int tensorIndex; + if (values[1] is long) + { + nodeIndex = (int)(long)values[1]; + } + else if (values[1] is int) + { + nodeIndex = (int)values[1]; + } + else { throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); } - if (values[2] is not int) + if (values[2] is long) + { + tensorIndex = (int)(long)values[2]; + } + else if (values[2] is int) + { + tensorIndex = (int)values[2]; + } + else { throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); } return new 
NodeConfig() { Name = values[0] as string, - NodeIndex = (int)values[1], - TensorIndex = (int)values[2] + NodeIndex = nodeIndex, + TensorIndex = tensorIndex }; } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs index cac19180..934d3b15 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs @@ -1,8 +1,11 @@ using Newtonsoft.Json; +using Newtonsoft.Json.Linq; using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; namespace Tensorflow.Keras.Saving { diff --git a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs index 5fe28c5d..0f514b42 100644 --- a/src/TensorFlowNET.Core/Tensors/TF_DataType.cs +++ b/src/TensorFlowNET.Core/Tensors/TF_DataType.cs @@ -1,9 +1,13 @@ -namespace Tensorflow +using Newtonsoft.Json; +using Tensorflow.Keras.Common; + +namespace Tensorflow { /// /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. /// The enum values here are identical to corresponding values in types.proto. 
/// + [JsonConverter(typeof(CustomizedDTypeJsonConverter))] public enum TF_DataType { DtInvalid = 0, diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index deeb9e4b..3563f91a 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -159,7 +159,10 @@ namespace Tensorflow "uint32" => TF_DataType.TF_UINT32, "int64" => TF_DataType.TF_INT64, "uint64" => TF_DataType.TF_UINT64, + "float16" => TF_DataType.TF_HALF, + "float32" => TF_DataType.TF_FLOAT, "single" => TF_DataType.TF_FLOAT, + "float64" => TF_DataType.TF_DOUBLE, "double" => TF_DataType.TF_DOUBLE, "complex" => TF_DataType.TF_COMPLEX128, "string" => TF_DataType.TF_STRING, diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs new file mode 100644 index 00000000..341a12ab --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Functions; + +namespace Tensorflow.Training.Saving.SavedModel +{ + /// + /// A class that wraps a concrete function to handle different distributed contexts. 
+ /// + internal class WrapperFunction: ConcreteFunction + { + public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph) + { + this.forward_backward = concrete_function.forward_backward; + this.Outputs = concrete_function.Outputs; + this.ReturnType = concrete_function.ReturnType; + this.OutputStructure = concrete_function.OutputStructure; + this.ArgKeywords = concrete_function.ArgKeywords; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs new file mode 100644 index 00000000..5b482872 --- /dev/null +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Functions; +using Tensorflow.Util; + +namespace Tensorflow.Training.Saving.SavedModel +{ + public static class function_deserialization + { + public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, + IDictionary concrete_functions) + { + var concrete_function = concrete_functions[saved_bare_concrete_function.ConcreteFunctionName]; + concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList(); + concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; + + var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); + concrete_function.AddTograph(); + return concrete_function; + } + + private static FunctionSpec _deserialize_function_spec_as_nonmethod(FunctionSpec function_spec_proto) + { + // TODO(Rinne); revise the implementation. 
+ return new FunctionSpec() + { + Fullargspec = function_spec_proto.Fullargspec, + IsMethod = function_spec_proto.IsMethod, + InputSignature = function_spec_proto.InputSignature, + JitCompile = function_spec_proto.JitCompile + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index 9e2654a7..da999b37 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -12,6 +12,7 @@ using static Tensorflow.Binding; using System.Runtime.CompilerServices; using Tensorflow.Variables; using Tensorflow.Functions; +using Tensorflow.Training.Saving.SavedModel; namespace Tensorflow { @@ -307,6 +308,11 @@ namespace Tensorflow foreach(var (node_id, proto) in _iter_all_nodes()) { var node = get(node_id); + if(node is null) + { + // skip it because now we skip the restoration of `Function` and `ConcreteFunction`. + continue; + } if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) { // Restore Trackable serialize- and restore-from-tensor functions. @@ -376,6 +382,13 @@ namespace Tensorflow } else { + // skip the function and concrete function. + if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function) + { + nodes[node_id] = null; + node_setters[node_id] = null; + continue; + } var (node, setter) = _recreate(proto, node_id, nodes); nodes[node_id] = node; node_setters[node_id] = setter; @@ -480,6 +493,11 @@ namespace Tensorflow foreach(var refer in proto.Children) { + if(obj is null) + { + // skip it because now we skip the restoration of `Function` and `ConcreteFunction`. 
+ continue; + } setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); // skip the process of "__call__" } @@ -591,6 +609,13 @@ namespace Tensorflow } } + private (ConcreteFunction, Action) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, + Dictionary, Trackable> dependencies) + { + throw new NotImplementedException(); + //var fn = function_deserialization.setup_bare_concrete_function(proto, ) + } + // TODO: remove this to a common class. public static Action setattr = (x, y, z) => { diff --git a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs index 2ea1b82e..f4407265 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs @@ -24,10 +24,10 @@ namespace Tensorflow.Keras.Engine /// /// /// - public static (Tensors, Tensors, Dictionary) reconstruct_from_config(ModelConfig config) + public static (Tensors, Tensors, Dictionary) reconstruct_from_config(ModelConfig config, Dictionary? created_layers = null) { // Layer instances created during the graph reconstruction process. - var created_layers = new Dictionary(); + created_layers = created_layers ?? 
new Dictionary(); var node_index_map = new Dictionary<(string, int), int>(); var node_count_by_layer = new Dictionary(); var unprocessed_nodes = new Dictionary(); @@ -88,12 +88,7 @@ namespace Tensorflow.Keras.Engine layer = created_layers[layer_name]; else { - layer = layer_data.ClassName switch - { - "InputLayer" => InputLayer.from_config(layer_data.Config), - "Dense" => Dense.from_config(layer_data.Config), - _ => throw new NotImplementedException("") - }; + layer = generic_utils.deserialize_keras_object(layer_data.ClassName, layer_data.Config); created_layers[layer_name] = layer; } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs index fc405d87..ed5c2de0 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs @@ -12,7 +12,7 @@ public abstract partial class Layer public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; - public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; + public string GetTrackingMetadata() => TrackableSavedModelSaver.TrackingMetadata; public override IDictionary _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary>? 
cache = null) { diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index 81f3a7d9..e54b939f 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -96,7 +96,6 @@ namespace Tensorflow.Keras.Engine List inboundNodes; public List InboundNodes => inboundNodes; - List outboundNodes; public List OutboundNodes => outboundNodes; diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index ca8007d0..56fde9f2 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -85,10 +85,5 @@ namespace Tensorflow.Keras.Layers return outputs; } - - public static Dense from_config(LayerArgs args) - { - return new Dense(args as DenseArgs); - } } } diff --git a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs index 03b4b742..a44c0bde 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs @@ -102,11 +102,6 @@ namespace Tensorflow.Keras.Layers name: Name); } - public static InputLayer from_config(LayerArgs args) - { - return new InputLayer(args as InputLayerArgs); - } - public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); } } diff --git a/src/TensorFlowNET.Keras/Models/ModelsApi.cs b/src/TensorFlowNET.Keras/Models/ModelsApi.cs index 73b77bc4..6597f5cd 100644 --- a/src/TensorFlowNET.Keras/Models/ModelsApi.cs +++ b/src/TensorFlowNET.Keras/Models/ModelsApi.cs @@ -4,6 +4,7 @@ using System.IO; using System.Text; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Saving; +using Tensorflow.Keras.Saving.SavedModel; using ThirdParty.Tensorflow.Python.Keras.Protobuf; namespace Tensorflow.Keras.Models @@ -13,20 +14,9 @@ namespace Tensorflow.Keras.Models public Functional from_config(ModelConfig config) => Functional.from_config(config); - public void 
load_model(string filepath, bool compile = true) + public Model load_model(string filepath, bool compile = true, LoadOptions? options = null) { - var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb")); - var saved_mode = SavedModel.Parser.ParseFrom(bytes); - - var meta_graph_def = saved_mode.MetaGraphs[0]; - var object_graph_def = meta_graph_def.ObjectGraphDef; - - bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb")); - var metadata = SavedMetadata.Parser.ParseFrom(bytes); - - // Recreate layers and metrics using the info stored in the metadata. - var keras_loader = new KerasObjectLoader(metadata, object_graph_def); - keras_loader.load_layers(compile: compile); + return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model; } } } diff --git a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs index cf9e4652..eb167b94 100644 --- a/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs +++ b/src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs @@ -164,11 +164,11 @@ namespace Tensorflow.Keras.Saving { if (config["layers"][0]["class_name"].ToObject() == "InputLayer") { - layers.Insert(0, InputLayer.from_config(config["layers"][0]["config"].ToObject())); + layers.Insert(0, new InputLayer(config["layers"][0]["config"].ToObject())); } else if (config["layers"][0]["config"]["batch_input_shape"] is not null) { - // TODO: implement it + // TODO(Rinne): implement it } } @@ -192,7 +192,8 @@ namespace Tensorflow.Keras.Saving else { // skip the parameter `created_layers`. 
- var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(config.ToObject()); + var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(generic_utils.deserialize_model_config(config), + layers.ToDictionary(x => x.Name, x => x as ILayer)); // skip the `model.__init__` (model as Functional).Initialize(inputs, outputs, config["name"].ToObject()); (model as Functional).connect_ancillary_layers(created_layers); @@ -283,7 +284,6 @@ namespace Tensorflow.Keras.Saving private (Trackable, Action) _load_layer(int node_id, string identifier, string metadata_json) { - metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); var metadata = JsonConvert.DeserializeObject(metadata_json); if (loaded_nodes.ContainsKey(node_id)) diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index b4e5a889..60ca6332 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -95,7 +95,7 @@ public partial class KerasSavedModelUtils BadConsumers = { } }, Identifier = layer.ObjectIdentifier, - Metadata = layer.TrackingMetadata + Metadata = layer.GetTrackingMetadata() }; metadata.Nodes.Add(saved_object); diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs index e7cb5b3a..abb2012f 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/load.cs @@ -44,7 +44,7 @@ namespace Tensorflow.Keras.Saving.SavedModel } } - public static Trackable load(string path, bool compile = true, LoadOptions? options = null) + private static Trackable load(string path, bool compile = true, LoadOptions? 
options = null) { SavedMetadata metadata = new SavedMetadata(); var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; @@ -82,12 +82,12 @@ namespace Tensorflow.Keras.Saving.SavedModel if(model is Model && compile) { - // TODO: implement it. + // TODO(Rinne): implement it. } if (!tf.Context.executing_eagerly()) { - // TODO: implement it. + // TODO(Rinne): implement it. } return model; diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index fffa4b8a..ccc8aca2 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Utils return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); } - public static Layer deserialize_keras_object(string class_name, JObject config) + public static Layer deserialize_keras_object(string class_name, JToken config) { return class_name switch { @@ -70,6 +70,58 @@ namespace Tensorflow.Keras.Utils }; } + public static Layer deserialize_keras_object(string class_name, LayerArgs args) + { + return class_name switch + { + "Sequential" => new Sequential(args as SequentialArgs), + "InputLayer" => new InputLayer(args as InputLayerArgs), + "Flatten" => new Flatten(args as FlattenArgs), + "ELU" => new ELU(args as ELUArgs), + "Dense" => new Dense(args as DenseArgs), + "Softmax" => new Softmax(args as SoftmaxArgs), + _ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + + $"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") + }; + } + + public static LayerArgs? 
deserialize_layer_args(string class_name, JToken config) + { + return class_name switch + { + "Sequential" => config.ToObject(), + "InputLayer" => config.ToObject(), + "Flatten" => config.ToObject(), + "ELU" => config.ToObject(), + "Dense" => config.ToObject(), + "Softmax" => config.ToObject(), + _ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + + $"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") + }; + } + + public static ModelConfig deserialize_model_config(JToken json) + { + ModelConfig config = new ModelConfig(); + config.Name = json["name"].ToObject(); + config.Layers = new List(); + var layersToken = json["layers"]; + foreach (var token in layersToken) + { + var args = deserialize_layer_args(token["class_name"].ToObject(), token["config"]); + config.Layers.Add(new LayerConfig() + { + Config = args, + Name = token["name"].ToObject(), + ClassName = token["class_name"].ToObject(), + InboundNodes = token["inbound_nodes"].ToObject>() + }); + } + config.InputLayers = json["input_layers"].ToObject>(); + config.OutputLayers = json["output_layers"].ToObject>(); + return config; + } + public static string to_snake_case(string name) { return string.Concat(name.Select((x, i) => diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index 57a69249..672f8d09 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -21,17 +21,16 @@ public class SequentialModelLoad [TestMethod] public void SimpleModelFromSequential() { - var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/model.pb"); - Debug.Assert(model is Model); - var m = model as Model; + new SequentialModelSave().SimpleModelFromSequential(); + var model = 
keras.models.load_model(@"./pb_simple_sequential"); - m.summary(); + model.summary(); - m.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + model.compile(new Adam(0.0001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); var data_loader = new MnistModelLoader(); var num_epochs = 1; - var batch_size = 50; + var batch_size = 8; var dataset = data_loader.LoadAsync(new ModelLoadSetting { @@ -40,6 +39,6 @@ public class SequentialModelLoad ValidationSize = 50000, }).Result; - m.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); } } diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs index d24049fb..efefa9a0 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs @@ -1,27 +1,21 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; -using Tensorflow.NumPy; -using System; using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; +using System.Diagnostics; using Tensorflow; -using static Tensorflow.Binding; -using static Tensorflow.KerasApi; using Tensorflow.Keras; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Losses; -using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Optimizers; -using Tensorflow.Operations; -using System.Diagnostics; +using Tensorflow.NumPy; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; namespace TensorFlowNET.Keras.UnitTest.SaveModel; [TestClass] -public class SequentialModelTest +public class SequentialModelSave { [TestMethod] public void SimpleModelFromAutoCompile() @@ -118,7 +112,7 @@ public class SequentialModelTest 
keras.layers.Softmax(1) }); - model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); var num_epochs = 1; var batch_size = 8;