| @@ -62,14 +62,17 @@ namespace Tensorflow.Checkpoint | |||||
| return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name); | return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name); | ||||
| } | } | ||||
| public unsafe Tensor GetTensor(string name) | |||||
| public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||||
| { | { | ||||
| Status status = new Status(); | Status status = new Status(); | ||||
| var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle); | var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle); | ||||
| status.Check(true); | status.Check(true); | ||||
| var shape = GetVariableShape(name); | var shape = GetVariableShape(name); | ||||
| var dtype = GetVariableDataType(name); | |||||
| return new Tensor(c_api.TF_TensorData(tensor), shape, dtype); | |||||
| if(dtype == TF_DataType.DtInvalid) | |||||
| { | |||||
| dtype = GetVariableDataType(name); | |||||
| } | |||||
| return new Tensor(tensor); | |||||
| } | } | ||||
| private void ReadAllShapeAndType() | private void ReadAllShapeAndType() | ||||
| @@ -227,7 +227,7 @@ public class TrackableSaver | |||||
| { | { | ||||
| dtype_map = reader.VariableToDataTypeMap; | dtype_map = reader.VariableToDataTypeMap; | ||||
| } | } | ||||
| Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY); | |||||
| Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING); | |||||
| Dictionary<Tensor, string> file_prefix_feed_dict; | Dictionary<Tensor, string> file_prefix_feed_dict; | ||||
| Tensor file_prefix_tensor; | Tensor file_prefix_tensor; | ||||
| @@ -249,7 +249,14 @@ public class TrackableSaver | |||||
| file_prefix_feed_dict = null; | file_prefix_feed_dict = null; | ||||
| } | } | ||||
| TrackableObjectGraph object_graph_proto = new(); | TrackableObjectGraph object_graph_proto = new(); | ||||
| object_graph_proto.MergeFrom(object_graph_string.BufferToArray()); | |||||
| if(object_graph_string.ndim > 0) | |||||
| { | |||||
| object_graph_proto.MergeFrom(object_graph_string.BufferToArray()); | |||||
| } | |||||
| else | |||||
| { | |||||
| object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]); | |||||
| } | |||||
| CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator( | CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator( | ||||
| object_graph_proto: object_graph_proto, | object_graph_proto: object_graph_proto, | ||||
| save_path: save_path, | save_path: save_path, | ||||
| @@ -13,8 +13,8 @@ namespace Tensorflow.Functions | |||||
| /// </summary> | /// </summary> | ||||
| public class ConcreteFunction: Trackable | public class ConcreteFunction: Trackable | ||||
| { | { | ||||
| FuncGraph func_graph; | |||||
| ForwardBackwardCall forward_backward; | |||||
| internal FuncGraph func_graph; | |||||
| internal ForwardBackwardCall forward_backward; | |||||
| public Tensor[] Inputs => func_graph.Inputs; | public Tensor[] Inputs => func_graph.Inputs; | ||||
| public Tensor[] CapturedInputs => func_graph.external_captures; | public Tensor[] CapturedInputs => func_graph.external_captures; | ||||
| @@ -23,6 +23,8 @@ namespace Tensorflow.Functions | |||||
| public Tensor[] Outputs; | public Tensor[] Outputs; | ||||
| public Type ReturnType; | public Type ReturnType; | ||||
| public TensorSpec[] OutputStructure; | public TensorSpec[] OutputStructure; | ||||
| public IEnumerable<string> ArgKeywords { get; set; } | |||||
| public long NumPositionArgs { get; set; } | |||||
| public ConcreteFunction(string name) | public ConcreteFunction(string name) | ||||
| { | { | ||||
| @@ -163,6 +165,15 @@ namespace Tensorflow.Functions | |||||
| return flat_outputs; | return flat_outputs; | ||||
| } | } | ||||
| public void AddTograph(Graph? g = null) | |||||
| { | |||||
| if(!tf.Context.executing_eagerly() && g is null) | |||||
| { | |||||
| g = ops.get_default_graph(); | |||||
| } | |||||
| // TODO(Rinne); complete it with `_delayed_rewrite_functions`. | |||||
| } | |||||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | ||||
| { | { | ||||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | ||||
| @@ -0,0 +1,39 @@ | |||||
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
    /// <summary>
    /// JSON converter for <see cref="TF_DataType"/>: writes the enum's default
    /// JSON representation and, when reading, accepts either a dtype name
    /// (e.g. "float32") or the raw enum value.
    /// </summary>
    public class CustomizedDTypeJsonConverter : JsonConverter
    {
        public override bool CanConvert(Type objectType)
        {
            return objectType == typeof(TF_DataType);
        }

        public override bool CanRead => true;

        public override bool CanWrite => true;

        public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
        {
            // `JToken.FromObject(null)` throws ArgumentNullException; emit an
            // explicit JSON null for a null value instead of crashing.
            if (value is null)
            {
                writer.WriteNull();
                return;
            }
            var token = JToken.FromObject(value);
            token.WriteTo(writer);
        }

        public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
        {
            if (reader.ValueType == typeof(string))
            {
                // Keras configs commonly store dtypes by name, e.g. "float32".
                var str = (string)serializer.Deserialize(reader, typeof(string));
                return dtypes.tf_dtype_from_name(str);
            }
            else
            {
                return (TF_DataType)serializer.Deserialize(reader, typeof(TF_DataType));
            }
        }
    }
}
| @@ -46,7 +46,16 @@ namespace Tensorflow.Keras.Common | |||||
| { | { | ||||
| throw new ValueError("Cannot deserialize 'null' to `Shape`."); | throw new ValueError("Cannot deserialize 'null' to `Shape`."); | ||||
| } | } | ||||
| if(values.Length != 3) | |||||
| if(values.Length == 1) | |||||
| { | |||||
| var array = values[0] as JArray; | |||||
| if(array is null) | |||||
| { | |||||
| throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | |||||
| } | |||||
| values = array.ToObject<object[]>(); | |||||
| } | |||||
| if (values.Length < 3) | |||||
| { | { | ||||
| throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | ||||
| } | } | ||||
| @@ -54,19 +63,37 @@ namespace Tensorflow.Keras.Common | |||||
| { | { | ||||
| throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); | throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); | ||||
| } | } | ||||
| if (values[1] is not int) | |||||
| int nodeIndex; | |||||
| int tensorIndex; | |||||
| if (values[1] is long) | |||||
| { | |||||
| nodeIndex = (int)(long)values[1]; | |||||
| } | |||||
| else if (values[1] is int) | |||||
| { | |||||
| nodeIndex = (int)values[1]; | |||||
| } | |||||
| else | |||||
| { | { | ||||
| throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); | throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); | ||||
| } | } | ||||
| if (values[2] is not int) | |||||
| if (values[2] is long) | |||||
| { | |||||
| tensorIndex = (int)(long)values[2]; | |||||
| } | |||||
| else if (values[1] is int) | |||||
| { | |||||
| tensorIndex = (int)values[2]; | |||||
| } | |||||
| else | |||||
| { | { | ||||
| throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); | throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); | ||||
| } | } | ||||
| return new NodeConfig() | return new NodeConfig() | ||||
| { | { | ||||
| Name = values[0] as string, | Name = values[0] as string, | ||||
| NodeIndex = (int)values[1], | |||||
| TensorIndex = (int)values[2] | |||||
| NodeIndex = nodeIndex, | |||||
| TensorIndex = tensorIndex | |||||
| }; | }; | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,8 +1,11 @@ | |||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | |||||
| namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
| { | { | ||||
| @@ -1,9 +1,13 @@ | |||||
| namespace Tensorflow | |||||
| using Newtonsoft.Json; | |||||
| using Tensorflow.Keras.Common; | |||||
| namespace Tensorflow | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. | /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. | ||||
| /// The enum values here are identical to corresponding values in types.proto. | /// The enum values here are identical to corresponding values in types.proto. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonConverter(typeof(CustomizedDTypeJsonConverter))] | |||||
| public enum TF_DataType | public enum TF_DataType | ||||
| { | { | ||||
| DtInvalid = 0, | DtInvalid = 0, | ||||
| @@ -159,7 +159,10 @@ namespace Tensorflow | |||||
| "uint32" => TF_DataType.TF_UINT32, | "uint32" => TF_DataType.TF_UINT32, | ||||
| "int64" => TF_DataType.TF_INT64, | "int64" => TF_DataType.TF_INT64, | ||||
| "uint64" => TF_DataType.TF_UINT64, | "uint64" => TF_DataType.TF_UINT64, | ||||
| "float16" => TF_DataType.TF_BFLOAT16, | |||||
| "float32" => TF_DataType.TF_FLOAT, | |||||
| "single" => TF_DataType.TF_FLOAT, | "single" => TF_DataType.TF_FLOAT, | ||||
| "float64" => TF_DataType.TF_DOUBLE, | |||||
| "double" => TF_DataType.TF_DOUBLE, | "double" => TF_DataType.TF_DOUBLE, | ||||
| "complex" => TF_DataType.TF_COMPLEX128, | "complex" => TF_DataType.TF_COMPLEX128, | ||||
| "string" => TF_DataType.TF_STRING, | "string" => TF_DataType.TF_STRING, | ||||
| @@ -0,0 +1,22 @@ | |||||
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Functions;

namespace Tensorflow.Training.Saving.SavedModel
{
    /// <summary>
    /// A class wraps a concrete function to handle different distributed contexts.
    /// </summary>
    internal class WrapperFunction : ConcreteFunction
    {
        public WrapperFunction(ConcreteFunction concrete_function) : base(concrete_function.func_graph)
        {
            this.forward_backward = concrete_function.forward_backward;
            this.Outputs = concrete_function.Outputs;
            this.ReturnType = concrete_function.ReturnType;
            this.OutputStructure = concrete_function.OutputStructure;
            this.ArgKeywords = concrete_function.ArgKeywords;
            // Also carry over the positional-argument count: it was previously
            // dropped even though every sibling member is copied.
            this.NumPositionArgs = concrete_function.NumPositionArgs;
        }
    }
}
| @@ -0,0 +1,36 @@ | |||||
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Functions;
using Tensorflow.Util;

namespace Tensorflow.Training.Saving.SavedModel
{
    public static class function_deserialization
    {
        /// <summary>
        /// Configures a restored bare concrete function (one saved without an
        /// enclosing `tf.function`) from its proto: restores argument keywords,
        /// the allowed positional-argument count, and registers it on a graph.
        /// </summary>
        /// <param name="saved_bare_concrete_function">The saved function proto.</param>
        /// <param name="concrete_functions">Restored concrete functions keyed by name.</param>
        /// <returns>The configured concrete function.</returns>
        public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function,
            IDictionary<string, ConcreteFunction> concrete_functions)
        {
            // Fail with a descriptive message instead of a bare missing-key error.
            if (!concrete_functions.TryGetValue(saved_bare_concrete_function.ConcreteFunctionName, out var concrete_function))
            {
                throw new KeyNotFoundException($"Cannot find the concrete function " +
                    $"`{saved_bare_concrete_function.ConcreteFunctionName}` among the restored functions.");
            }
            concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList();
            concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments;

            // NOTE(review): the deserialized spec is currently unused; it should be
            // consumed once `_deserialize_function_spec_as_nonmethod` is complete.
            var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec);
            concrete_function.AddTograph();
            return concrete_function;
        }

        // Copies the spec fields as-is; method-to-function argspec rewriting is
        // not implemented yet.
        private static FunctionSpec _deserialize_function_spec_as_nonmethod(FunctionSpec function_spec_proto)
        {
            // TODO(Rinne): revise the implementation.
            return new FunctionSpec()
            {
                Fullargspec = function_spec_proto.Fullargspec,
                IsMethod = function_spec_proto.IsMethod,
                InputSignature = function_spec_proto.InputSignature,
                JitCompile = function_spec_proto.JitCompile
            };
        }
    }
}
| @@ -12,6 +12,7 @@ using static Tensorflow.Binding; | |||||
| using System.Runtime.CompilerServices; | using System.Runtime.CompilerServices; | ||||
| using Tensorflow.Variables; | using Tensorflow.Variables; | ||||
| using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
| using Tensorflow.Training.Saving.SavedModel; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| @@ -307,6 +308,11 @@ namespace Tensorflow | |||||
| foreach(var (node_id, proto) in _iter_all_nodes()) | foreach(var (node_id, proto) in _iter_all_nodes()) | ||||
| { | { | ||||
| var node = get(node_id); | var node = get(node_id); | ||||
| if(node is null) | |||||
| { | |||||
| // skip it because now we skip the restoration of `Function` and `ConcreteFunction`. | |||||
| continue; | |||||
| } | |||||
| if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | ||||
| { | { | ||||
| // Restore Trackable serialize- and restore-from-tensor functions. | // Restore Trackable serialize- and restore-from-tensor functions. | ||||
| @@ -376,6 +382,13 @@ namespace Tensorflow | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| // skip the function and concrete function. | |||||
| if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function) | |||||
| { | |||||
| nodes[node_id] = null; | |||||
| node_setters[node_id] = null; | |||||
| continue; | |||||
| } | |||||
| var (node, setter) = _recreate(proto, node_id, nodes); | var (node, setter) = _recreate(proto, node_id, nodes); | ||||
| nodes[node_id] = node; | nodes[node_id] = node; | ||||
| node_setters[node_id] = setter; | node_setters[node_id] = setter; | ||||
| @@ -480,6 +493,11 @@ namespace Tensorflow | |||||
| foreach(var refer in proto.Children) | foreach(var refer in proto.Children) | ||||
| { | { | ||||
| if(obj is null) | |||||
| { | |||||
| // skip it because now we skip the restoration of `Function` and `ConcreteFunction`. | |||||
| continue; | |||||
| } | |||||
| setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); | setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); | ||||
| // skip the process of "__call__" | // skip the process of "__call__" | ||||
| } | } | ||||
| @@ -591,6 +609,13 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | |||||
| Dictionary<Maybe<string, int>, Trackable> dependencies) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| //var fn = function_deserialization.setup_bare_concrete_function(proto, ) | |||||
| } | |||||
| // TODO: remove this to a common class. | // TODO: remove this to a common class. | ||||
| public static Action<object, object, object> setattr = (x, y, z) => | public static Action<object, object, object> setattr = (x, y, z) => | ||||
| { | { | ||||
| @@ -24,10 +24,10 @@ namespace Tensorflow.Keras.Engine | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="config"></param> | /// <param name="config"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config) | |||||
| public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config, Dictionary<string, ILayer>? created_layers = null) | |||||
| { | { | ||||
| // Layer instances created during the graph reconstruction process. | // Layer instances created during the graph reconstruction process. | ||||
| var created_layers = new Dictionary<string, ILayer>(); | |||||
| created_layers = created_layers ?? new Dictionary<string, ILayer>(); | |||||
| var node_index_map = new Dictionary<(string, int), int>(); | var node_index_map = new Dictionary<(string, int), int>(); | ||||
| var node_count_by_layer = new Dictionary<ILayer, int>(); | var node_count_by_layer = new Dictionary<ILayer, int>(); | ||||
| var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>(); | var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>(); | ||||
| @@ -88,12 +88,7 @@ namespace Tensorflow.Keras.Engine | |||||
| layer = created_layers[layer_name]; | layer = created_layers[layer_name]; | ||||
| else | else | ||||
| { | { | ||||
| layer = layer_data.ClassName switch | |||||
| { | |||||
| "InputLayer" => InputLayer.from_config(layer_data.Config), | |||||
| "Dense" => Dense.from_config(layer_data.Config), | |||||
| _ => throw new NotImplementedException("") | |||||
| }; | |||||
| layer = generic_utils.deserialize_keras_object(layer_data.ClassName, layer_data.Config); | |||||
| created_layers[layer_name] = layer; | created_layers[layer_name] = layer; | ||||
| } | } | ||||
| @@ -12,7 +12,7 @@ public abstract partial class Layer | |||||
| public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | ||||
| public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; | |||||
| public string GetTrackingMetadata() => TrackableSavedModelSaver.TrackingMetadata; | |||||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | ||||
| { | { | ||||
| @@ -96,7 +96,6 @@ namespace Tensorflow.Keras.Engine | |||||
| List<INode> inboundNodes; | List<INode> inboundNodes; | ||||
| public List<INode> InboundNodes => inboundNodes; | public List<INode> InboundNodes => inboundNodes; | ||||
| List<INode> outboundNodes; | List<INode> outboundNodes; | ||||
| public List<INode> OutboundNodes => outboundNodes; | public List<INode> OutboundNodes => outboundNodes; | ||||
| @@ -85,10 +85,5 @@ namespace Tensorflow.Keras.Layers | |||||
| return outputs; | return outputs; | ||||
| } | } | ||||
| public static Dense from_config(LayerArgs args) | |||||
| { | |||||
| return new Dense(args as DenseArgs); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -102,11 +102,6 @@ namespace Tensorflow.Keras.Layers | |||||
| name: Name); | name: Name); | ||||
| } | } | ||||
| public static InputLayer from_config(LayerArgs args) | |||||
| { | |||||
| return new InputLayer(args as InputLayerArgs); | |||||
| } | |||||
| public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); | public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); | ||||
| } | } | ||||
| } | } | ||||
| @@ -4,6 +4,7 @@ using System.IO; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| using Tensorflow.Keras.Saving.SavedModel; | |||||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | using ThirdParty.Tensorflow.Python.Keras.Protobuf; | ||||
| namespace Tensorflow.Keras.Models | namespace Tensorflow.Keras.Models | ||||
| @@ -13,20 +14,9 @@ namespace Tensorflow.Keras.Models | |||||
| public Functional from_config(ModelConfig config) | public Functional from_config(ModelConfig config) | ||||
| => Functional.from_config(config); | => Functional.from_config(config); | ||||
| public void load_model(string filepath, bool compile = true) | |||||
| public Model load_model(string filepath, bool compile = true, LoadOptions? options = null) | |||||
| { | { | ||||
| var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb")); | |||||
| var saved_mode = SavedModel.Parser.ParseFrom(bytes); | |||||
| var meta_graph_def = saved_mode.MetaGraphs[0]; | |||||
| var object_graph_def = meta_graph_def.ObjectGraphDef; | |||||
| bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb")); | |||||
| var metadata = SavedMetadata.Parser.ParseFrom(bytes); | |||||
| // Recreate layers and metrics using the info stored in the metadata. | |||||
| var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||||
| keras_loader.load_layers(compile: compile); | |||||
| return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -164,11 +164,11 @@ namespace Tensorflow.Keras.Saving | |||||
| { | { | ||||
| if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer") | if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer") | ||||
| { | { | ||||
| layers.Insert(0, InputLayer.from_config(config["layers"][0]["config"].ToObject<InputLayerArgs>())); | |||||
| layers.Insert(0, new InputLayer(config["layers"][0]["config"].ToObject<InputLayerArgs>())); | |||||
| } | } | ||||
| else if (config["layers"][0]["config"]["batch_input_shape"] is not null) | else if (config["layers"][0]["config"]["batch_input_shape"] is not null) | ||||
| { | { | ||||
| // TODO: implement it | |||||
| // TODO(Rinne): implement it | |||||
| } | } | ||||
| } | } | ||||
| @@ -192,7 +192,8 @@ namespace Tensorflow.Keras.Saving | |||||
| else | else | ||||
| { | { | ||||
| // skip the parameter `created_layers`. | // skip the parameter `created_layers`. | ||||
| var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(config.ToObject<ModelConfig>()); | |||||
| var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(generic_utils.deserialize_model_config(config), | |||||
| layers.ToDictionary(x => x.Name, x => x as ILayer)); | |||||
| // skip the `model.__init__` | // skip the `model.__init__` | ||||
| (model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>()); | (model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>()); | ||||
| (model as Functional).connect_ancillary_layers(created_layers); | (model as Functional).connect_ancillary_layers(created_layers); | ||||
| @@ -283,7 +284,6 @@ namespace Tensorflow.Keras.Saving | |||||
| private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) | private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) | ||||
| { | { | ||||
| metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); | |||||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | ||||
| if (loaded_nodes.ContainsKey(node_id)) | if (loaded_nodes.ContainsKey(node_id)) | ||||
| @@ -95,7 +95,7 @@ public partial class KerasSavedModelUtils | |||||
| BadConsumers = { } | BadConsumers = { } | ||||
| }, | }, | ||||
| Identifier = layer.ObjectIdentifier, | Identifier = layer.ObjectIdentifier, | ||||
| Metadata = layer.TrackingMetadata | |||||
| Metadata = layer.GetTrackingMetadata() | |||||
| }; | }; | ||||
| metadata.Nodes.Add(saved_object); | metadata.Nodes.Add(saved_object); | ||||
| @@ -44,7 +44,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| } | } | ||||
| } | } | ||||
| public static Trackable load(string path, bool compile = true, LoadOptions? options = null) | |||||
| private static Trackable load(string path, bool compile = true, LoadOptions? options = null) | |||||
| { | { | ||||
| SavedMetadata metadata = new SavedMetadata(); | SavedMetadata metadata = new SavedMetadata(); | ||||
| var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; | var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; | ||||
| @@ -82,12 +82,12 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||||
| if(model is Model && compile) | if(model is Model && compile) | ||||
| { | { | ||||
| // TODO: implement it. | |||||
| // TODO(Rinne): implement it. | |||||
| } | } | ||||
| if (!tf.Context.executing_eagerly()) | if (!tf.Context.executing_eagerly()) | ||||
| { | { | ||||
| // TODO: implement it. | |||||
| // TODO(Rinne): implement it. | |||||
| } | } | ||||
| return model; | return model; | ||||
| @@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Utils | |||||
| return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); | return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); | ||||
| } | } | ||||
| public static Layer deserialize_keras_object(string class_name, JObject config) | |||||
| public static Layer deserialize_keras_object(string class_name, JToken config) | |||||
| { | { | ||||
| return class_name switch | return class_name switch | ||||
| { | { | ||||
| @@ -70,6 +70,58 @@ namespace Tensorflow.Keras.Utils | |||||
| }; | }; | ||||
| } | } | ||||
| public static Layer deserialize_keras_object(string class_name, LayerArgs args) | |||||
| { | |||||
| return class_name switch | |||||
| { | |||||
| "Sequential" => new Sequential(args as SequentialArgs), | |||||
| "InputLayer" => new InputLayer(args as InputLayerArgs), | |||||
| "Flatten" => new Flatten(args as FlattenArgs), | |||||
| "ELU" => new ELU(args as ELUArgs), | |||||
| "Dense" => new Dense(args as DenseArgs), | |||||
| "Softmax" => new Softmax(args as SoftmaxArgs), | |||||
| _ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + | |||||
| $"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") | |||||
| }; | |||||
| } | |||||
| public static LayerArgs? deserialize_layer_args(string class_name, JToken config) | |||||
| { | |||||
| return class_name switch | |||||
| { | |||||
| "Sequential" => config.ToObject<SequentialArgs>(), | |||||
| "InputLayer" => config.ToObject<InputLayerArgs>(), | |||||
| "Flatten" => config.ToObject<FlattenArgs>(), | |||||
| "ELU" => config.ToObject<ELUArgs>(), | |||||
| "Dense" => config.ToObject<DenseArgs>(), | |||||
| "Softmax" => config.ToObject<SoftmaxArgs>(), | |||||
| _ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + | |||||
| $"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") | |||||
| }; | |||||
| } | |||||
| public static ModelConfig deserialize_model_config(JToken json) | |||||
| { | |||||
| ModelConfig config = new ModelConfig(); | |||||
| config.Name = json["name"].ToObject<string>(); | |||||
| config.Layers = new List<LayerConfig>(); | |||||
| var layersToken = json["layers"]; | |||||
| foreach (var token in layersToken) | |||||
| { | |||||
| var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]); | |||||
| config.Layers.Add(new LayerConfig() | |||||
| { | |||||
| Config = args, | |||||
| Name = token["name"].ToObject<string>(), | |||||
| ClassName = token["class_name"].ToObject<string>(), | |||||
| InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>() | |||||
| }); | |||||
| } | |||||
| config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>(); | |||||
| config.OutputLayers = json["output_layers"].ToObject<List<NodeConfig>>(); | |||||
| return config; | |||||
| } | |||||
| public static string to_snake_case(string name) | public static string to_snake_case(string name) | ||||
| { | { | ||||
| return string.Concat(name.Select((x, i) => | return string.Concat(name.Select((x, i) => | ||||
| @@ -21,17 +21,16 @@ public class SequentialModelLoad | |||||
| [TestMethod] | [TestMethod] | ||||
| public void SimpleModelFromSequential() | public void SimpleModelFromSequential() | ||||
| { | { | ||||
| var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/model.pb"); | |||||
| Debug.Assert(model is Model); | |||||
| var m = model as Model; | |||||
| new SequentialModelSave().SimpleModelFromSequential(); | |||||
| var model = keras.models.load_model(@"./pb_simple_sequential"); | |||||
| m.summary(); | |||||
| model.summary(); | |||||
| m.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||||
| model.compile(new Adam(0.0001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||||
| var data_loader = new MnistModelLoader(); | var data_loader = new MnistModelLoader(); | ||||
| var num_epochs = 1; | var num_epochs = 1; | ||||
| var batch_size = 50; | |||||
| var batch_size = 8; | |||||
| var dataset = data_loader.LoadAsync(new ModelLoadSetting | var dataset = data_loader.LoadAsync(new ModelLoadSetting | ||||
| { | { | ||||
| @@ -40,6 +39,6 @@ public class SequentialModelLoad | |||||
| ValidationSize = 50000, | ValidationSize = 50000, | ||||
| }).Result; | }).Result; | ||||
| m.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,27 +1,21 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow.NumPy; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using System.Threading.Tasks; | |||||
| using System.Diagnostics; | |||||
| using Tensorflow; | using Tensorflow; | ||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.KerasApi; | |||||
| using Tensorflow.Keras; | using Tensorflow.Keras; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
| using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
| using Tensorflow.Keras.Metrics; | |||||
| using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
| using Tensorflow.Operations; | |||||
| using System.Diagnostics; | |||||
| using Tensorflow.NumPy; | |||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.KerasApi; | |||||
| namespace TensorFlowNET.Keras.UnitTest.SaveModel; | namespace TensorFlowNET.Keras.UnitTest.SaveModel; | ||||
| [TestClass] | [TestClass] | ||||
| public class SequentialModelTest | |||||
| public class SequentialModelSave | |||||
| { | { | ||||
| [TestMethod] | [TestMethod] | ||||
| public void SimpleModelFromAutoCompile() | public void SimpleModelFromAutoCompile() | ||||
| @@ -118,7 +112,7 @@ public class SequentialModelTest | |||||
| keras.layers.Softmax(1) | keras.layers.Softmax(1) | ||||
| }); | }); | ||||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); | |||||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); | |||||
| var num_epochs = 1; | var num_epochs = 1; | ||||
| var batch_size = 8; | var batch_size = 8; | ||||