| @@ -62,14 +62,17 @@ namespace Tensorflow.Checkpoint | |||
| return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name); | |||
| } | |||
| public unsafe Tensor GetTensor(string name) | |||
| public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) | |||
| { | |||
| Status status = new Status(); | |||
| var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle); | |||
| status.Check(true); | |||
| var shape = GetVariableShape(name); | |||
| var dtype = GetVariableDataType(name); | |||
| return new Tensor(c_api.TF_TensorData(tensor), shape, dtype); | |||
| if(dtype == TF_DataType.DtInvalid) | |||
| { | |||
| dtype = GetVariableDataType(name); | |||
| } | |||
| return new Tensor(tensor); | |||
| } | |||
| private void ReadAllShapeAndType() | |||
| @@ -227,7 +227,7 @@ public class TrackableSaver | |||
| { | |||
| dtype_map = reader.VariableToDataTypeMap; | |||
| } | |||
| Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY); | |||
| Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING); | |||
| Dictionary<Tensor, string> file_prefix_feed_dict; | |||
| Tensor file_prefix_tensor; | |||
| @@ -249,7 +249,14 @@ public class TrackableSaver | |||
| file_prefix_feed_dict = null; | |||
| } | |||
| TrackableObjectGraph object_graph_proto = new(); | |||
| object_graph_proto.MergeFrom(object_graph_string.BufferToArray()); | |||
| if(object_graph_string.ndim > 0) | |||
| { | |||
| object_graph_proto.MergeFrom(object_graph_string.BufferToArray()); | |||
| } | |||
| else | |||
| { | |||
| object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]); | |||
| } | |||
| CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator( | |||
| object_graph_proto: object_graph_proto, | |||
| save_path: save_path, | |||
| @@ -13,8 +13,8 @@ namespace Tensorflow.Functions | |||
| /// </summary> | |||
| public class ConcreteFunction: Trackable | |||
| { | |||
| FuncGraph func_graph; | |||
| ForwardBackwardCall forward_backward; | |||
| internal FuncGraph func_graph; | |||
| internal ForwardBackwardCall forward_backward; | |||
| public Tensor[] Inputs => func_graph.Inputs; | |||
| public Tensor[] CapturedInputs => func_graph.external_captures; | |||
| @@ -23,6 +23,8 @@ namespace Tensorflow.Functions | |||
| public Tensor[] Outputs; | |||
| public Type ReturnType; | |||
| public TensorSpec[] OutputStructure; | |||
| public IEnumerable<string> ArgKeywords { get; set; } | |||
| public long NumPositionArgs { get; set; } | |||
| public ConcreteFunction(string name) | |||
| { | |||
| @@ -163,6 +165,15 @@ namespace Tensorflow.Functions | |||
| return flat_outputs; | |||
| } | |||
| public void AddTograph(Graph? g = null) | |||
| { | |||
| if(!tf.Context.executing_eagerly() && g is null) | |||
| { | |||
| g = ops.get_default_graph(); | |||
| } | |||
| // TODO(Rinne): complete it with `_delayed_rewrite_functions`. | |||
| } | |||
| ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) | |||
| { | |||
| var functions = new FirstOrderTapeGradientFunctions(func_graph, false); | |||
| @@ -0,0 +1,39 @@ | |||
| using Newtonsoft.Json.Linq; | |||
| using Newtonsoft.Json; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Keras.Common | |||
| { | |||
| public class CustomizedDTypeJsonConverter : JsonConverter | |||
| { | |||
| public override bool CanConvert(Type objectType) | |||
| { | |||
| return objectType == typeof(TF_DataType); | |||
| } | |||
| public override bool CanRead => true; | |||
| public override bool CanWrite => true; | |||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||
| { | |||
| var token = JToken.FromObject(value); | |||
| token.WriteTo(writer); | |||
| } | |||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||
| { | |||
| if (reader.ValueType == typeof(string)) | |||
| { | |||
| var str = (string)serializer.Deserialize(reader, typeof(string)); | |||
| return dtypes.tf_dtype_from_name(str); | |||
| } | |||
| else | |||
| { | |||
| return (TF_DataType)serializer.Deserialize(reader, typeof(TF_DataType)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -46,7 +46,16 @@ namespace Tensorflow.Keras.Common | |||
| { | |||
| throw new ValueError("Cannot deserialize 'null' to `Shape`."); | |||
| } | |||
| if(values.Length != 3) | |||
| if(values.Length == 1) | |||
| { | |||
| var array = values[0] as JArray; | |||
| if(array is null) | |||
| { | |||
| throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | |||
| } | |||
| values = array.ToObject<object[]>(); | |||
| } | |||
| if (values.Length < 3) | |||
| { | |||
| throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | |||
| } | |||
| @@ -54,19 +63,37 @@ namespace Tensorflow.Keras.Common | |||
| { | |||
| throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); | |||
| } | |||
| if (values[1] is not int) | |||
| int nodeIndex; | |||
| int tensorIndex; | |||
| if (values[1] is long) | |||
| { | |||
| nodeIndex = (int)(long)values[1]; | |||
| } | |||
| else if (values[1] is int) | |||
| { | |||
| nodeIndex = (int)values[1]; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError($"The second value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); | |||
| } | |||
| if (values[2] is not int) | |||
| if (values[2] is long) | |||
| { | |||
| tensorIndex = (int)(long)values[2]; | |||
| } | |||
| else if (values[2] is int) | |||
| { | |||
| tensorIndex = (int)values[2]; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError($"The third value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); | |||
| } | |||
| return new NodeConfig() | |||
| { | |||
| Name = values[0] as string, | |||
| NodeIndex = (int)values[1], | |||
| TensorIndex = (int)values[2] | |||
| NodeIndex = nodeIndex, | |||
| TensorIndex = tensorIndex | |||
| }; | |||
| } | |||
| } | |||
| @@ -1,8 +1,11 @@ | |||
| using Newtonsoft.Json; | |||
| using Newtonsoft.Json.Linq; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using static Google.Protobuf.Reflection.FieldDescriptorProto.Types; | |||
| namespace Tensorflow.Keras.Saving | |||
| { | |||
| @@ -1,9 +1,13 @@ | |||
| namespace Tensorflow | |||
| using Newtonsoft.Json; | |||
| using Tensorflow.Keras.Common; | |||
| namespace Tensorflow | |||
| { | |||
| /// <summary> | |||
| /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. | |||
| /// The enum values here are identical to corresponding values in types.proto. | |||
| /// </summary> | |||
| [JsonConverter(typeof(CustomizedDTypeJsonConverter))] | |||
| public enum TF_DataType | |||
| { | |||
| DtInvalid = 0, | |||
| @@ -159,7 +159,10 @@ namespace Tensorflow | |||
| "uint32" => TF_DataType.TF_UINT32, | |||
| "int64" => TF_DataType.TF_INT64, | |||
| "uint64" => TF_DataType.TF_UINT64, | |||
| "float16" => TF_DataType.TF_HALF, | |||
| "float32" => TF_DataType.TF_FLOAT, | |||
| "single" => TF_DataType.TF_FLOAT, | |||
| "float64" => TF_DataType.TF_DOUBLE, | |||
| "double" => TF_DataType.TF_DOUBLE, | |||
| "complex" => TF_DataType.TF_COMPLEX128, | |||
| "string" => TF_DataType.TF_STRING, | |||
| @@ -0,0 +1,22 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Functions; | |||
| namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| /// <summary> | |||
| /// A class wraps a concrete function to handle different distributed contexts. | |||
| /// </summary> | |||
| internal class WrapperFunction: ConcreteFunction | |||
| { | |||
| public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph) | |||
| { | |||
| this.forward_backward = concrete_function.forward_backward; | |||
| this.Outputs = concrete_function.Outputs; | |||
| this.ReturnType = concrete_function.ReturnType; | |||
| this.OutputStructure = concrete_function.OutputStructure; | |||
| this.ArgKeywords = concrete_function.ArgKeywords; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,36 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Util; | |||
| namespace Tensorflow.Training.Saving.SavedModel | |||
| { | |||
| public static class function_deserialization | |||
| { | |||
| public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function, | |||
| IDictionary<string, ConcreteFunction> concrete_functions) | |||
| { | |||
| var concrete_function = concrete_functions[saved_bare_concrete_function.ConcreteFunctionName]; | |||
| concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList(); | |||
| concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments; | |||
| var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec); | |||
| concrete_function.AddTograph(); | |||
| return concrete_function; | |||
| } | |||
| private static FunctionSpec _deserialize_function_spec_as_nonmethod(FunctionSpec function_spec_proto) | |||
| { | |||
| // TODO(Rinne): revise the implementation. | |||
| return new FunctionSpec() | |||
| { | |||
| Fullargspec = function_spec_proto.Fullargspec, | |||
| IsMethod = function_spec_proto.IsMethod, | |||
| InputSignature = function_spec_proto.InputSignature, | |||
| JitCompile = function_spec_proto.JitCompile | |||
| }; | |||
| } | |||
| } | |||
| } | |||
| @@ -12,6 +12,7 @@ using static Tensorflow.Binding; | |||
| using System.Runtime.CompilerServices; | |||
| using Tensorflow.Variables; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -307,6 +308,11 @@ namespace Tensorflow | |||
| foreach(var (node_id, proto) in _iter_all_nodes()) | |||
| { | |||
| var node = get(node_id); | |||
| if(node is null) | |||
| { | |||
| // skip it because now we skip the restoration of `Function` and `ConcreteFunction`. | |||
| continue; | |||
| } | |||
| if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) | |||
| { | |||
| // Restore Trackable serialize- and restore-from-tensor functions. | |||
| @@ -376,6 +382,13 @@ namespace Tensorflow | |||
| } | |||
| else | |||
| { | |||
| // skip the function and concrete function. | |||
| if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function) | |||
| { | |||
| nodes[node_id] = null; | |||
| node_setters[node_id] = null; | |||
| continue; | |||
| } | |||
| var (node, setter) = _recreate(proto, node_id, nodes); | |||
| nodes[node_id] = node; | |||
| node_setters[node_id] = setter; | |||
| @@ -480,6 +493,11 @@ namespace Tensorflow | |||
| foreach(var refer in proto.Children) | |||
| { | |||
| if(obj is null) | |||
| { | |||
| // skip it because now we skip the restoration of `Function` and `ConcreteFunction`. | |||
| continue; | |||
| } | |||
| setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); | |||
| // skip the process of "__call__" | |||
| } | |||
| @@ -591,6 +609,13 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | |||
| Dictionary<Maybe<string, int>, Trackable> dependencies) | |||
| { | |||
| throw new NotImplementedException(); | |||
| //var fn = function_deserialization.setup_bare_concrete_function(proto, ) | |||
| } | |||
| // TODO: remove this to a common class. | |||
| public static Action<object, object, object> setattr = (x, y, z) => | |||
| { | |||
| @@ -24,10 +24,10 @@ namespace Tensorflow.Keras.Engine | |||
| /// </summary> | |||
| /// <param name="config"></param> | |||
| /// <returns></returns> | |||
| public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config) | |||
| public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config, Dictionary<string, ILayer>? created_layers = null) | |||
| { | |||
| // Layer instances created during the graph reconstruction process. | |||
| var created_layers = new Dictionary<string, ILayer>(); | |||
| created_layers = created_layers ?? new Dictionary<string, ILayer>(); | |||
| var node_index_map = new Dictionary<(string, int), int>(); | |||
| var node_count_by_layer = new Dictionary<ILayer, int>(); | |||
| var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>(); | |||
| @@ -88,12 +88,7 @@ namespace Tensorflow.Keras.Engine | |||
| layer = created_layers[layer_name]; | |||
| else | |||
| { | |||
| layer = layer_data.ClassName switch | |||
| { | |||
| "InputLayer" => InputLayer.from_config(layer_data.Config), | |||
| "Dense" => Dense.from_config(layer_data.Config), | |||
| _ => throw new NotImplementedException("") | |||
| }; | |||
| layer = generic_utils.deserialize_keras_object(layer_data.ClassName, layer_data.Config); | |||
| created_layers[layer_name] = layer; | |||
| } | |||
| @@ -12,7 +12,7 @@ public abstract partial class Layer | |||
| public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; | |||
| public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata; | |||
| public string GetTrackingMetadata() => TrackableSavedModelSaver.TrackingMetadata; | |||
| public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) | |||
| { | |||
| @@ -96,7 +96,6 @@ namespace Tensorflow.Keras.Engine | |||
| List<INode> inboundNodes; | |||
| public List<INode> InboundNodes => inboundNodes; | |||
| List<INode> outboundNodes; | |||
| public List<INode> OutboundNodes => outboundNodes; | |||
| @@ -85,10 +85,5 @@ namespace Tensorflow.Keras.Layers | |||
| return outputs; | |||
| } | |||
| public static Dense from_config(LayerArgs args) | |||
| { | |||
| return new Dense(args as DenseArgs); | |||
| } | |||
| } | |||
| } | |||
| @@ -102,11 +102,6 @@ namespace Tensorflow.Keras.Layers | |||
| name: Name); | |||
| } | |||
| public static InputLayer from_config(LayerArgs args) | |||
| { | |||
| return new InputLayer(args as InputLayerArgs); | |||
| } | |||
| public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); | |||
| } | |||
| } | |||
| @@ -4,6 +4,7 @@ using System.IO; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Keras.Saving.SavedModel; | |||
| using ThirdParty.Tensorflow.Python.Keras.Protobuf; | |||
| namespace Tensorflow.Keras.Models | |||
| @@ -13,20 +14,9 @@ namespace Tensorflow.Keras.Models | |||
| public Functional from_config(ModelConfig config) | |||
| => Functional.from_config(config); | |||
| public void load_model(string filepath, bool compile = true) | |||
| public Model load_model(string filepath, bool compile = true, LoadOptions? options = null) | |||
| { | |||
| var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb")); | |||
| var saved_mode = SavedModel.Parser.ParseFrom(bytes); | |||
| var meta_graph_def = saved_mode.MetaGraphs[0]; | |||
| var object_graph_def = meta_graph_def.ObjectGraphDef; | |||
| bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb")); | |||
| var metadata = SavedMetadata.Parser.ParseFrom(bytes); | |||
| // Recreate layers and metrics using the info stored in the metadata. | |||
| var keras_loader = new KerasObjectLoader(metadata, object_graph_def); | |||
| keras_loader.load_layers(compile: compile); | |||
| return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model; | |||
| } | |||
| } | |||
| } | |||
| @@ -164,11 +164,11 @@ namespace Tensorflow.Keras.Saving | |||
| { | |||
| if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer") | |||
| { | |||
| layers.Insert(0, InputLayer.from_config(config["layers"][0]["config"].ToObject<InputLayerArgs>())); | |||
| layers.Insert(0, new InputLayer(config["layers"][0]["config"].ToObject<InputLayerArgs>())); | |||
| } | |||
| else if (config["layers"][0]["config"]["batch_input_shape"] is not null) | |||
| { | |||
| // TODO: implement it | |||
| // TODO(Rinne): implement it | |||
| } | |||
| } | |||
| @@ -192,7 +192,8 @@ namespace Tensorflow.Keras.Saving | |||
| else | |||
| { | |||
| // skip the parameter `created_layers`. | |||
| var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(config.ToObject<ModelConfig>()); | |||
| var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(generic_utils.deserialize_model_config(config), | |||
| layers.ToDictionary(x => x.Name, x => x as ILayer)); | |||
| // skip the `model.__init__` | |||
| (model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>()); | |||
| (model as Functional).connect_ancillary_layers(created_layers); | |||
| @@ -283,7 +284,6 @@ namespace Tensorflow.Keras.Saving | |||
| private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) | |||
| { | |||
| metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1"); | |||
| var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); | |||
| if (loaded_nodes.ContainsKey(node_id)) | |||
| @@ -95,7 +95,7 @@ public partial class KerasSavedModelUtils | |||
| BadConsumers = { } | |||
| }, | |||
| Identifier = layer.ObjectIdentifier, | |||
| Metadata = layer.TrackingMetadata | |||
| Metadata = layer.GetTrackingMetadata() | |||
| }; | |||
| metadata.Nodes.Add(saved_object); | |||
| @@ -44,7 +44,7 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| } | |||
| } | |||
| public static Trackable load(string path, bool compile = true, LoadOptions? options = null) | |||
| private static Trackable load(string path, bool compile = true, LoadOptions? options = null) | |||
| { | |||
| SavedMetadata metadata = new SavedMetadata(); | |||
| var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; | |||
| @@ -82,12 +82,12 @@ namespace Tensorflow.Keras.Saving.SavedModel | |||
| if(model is Model && compile) | |||
| { | |||
| // TODO: implement it. | |||
| // TODO(Rinne): implement it. | |||
| } | |||
| if (!tf.Context.executing_eagerly()) | |||
| { | |||
| // TODO: implement it. | |||
| // TODO(Rinne): implement it. | |||
| } | |||
| return model; | |||
| @@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Utils | |||
| return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); | |||
| } | |||
| public static Layer deserialize_keras_object(string class_name, JObject config) | |||
| public static Layer deserialize_keras_object(string class_name, JToken config) | |||
| { | |||
| return class_name switch | |||
| { | |||
| @@ -70,6 +70,58 @@ namespace Tensorflow.Keras.Utils | |||
| }; | |||
| } | |||
| public static Layer deserialize_keras_object(string class_name, LayerArgs args) | |||
| { | |||
| return class_name switch | |||
| { | |||
| "Sequential" => new Sequential(args as SequentialArgs), | |||
| "InputLayer" => new InputLayer(args as InputLayerArgs), | |||
| "Flatten" => new Flatten(args as FlattenArgs), | |||
| "ELU" => new ELU(args as ELUArgs), | |||
| "Dense" => new Dense(args as DenseArgs), | |||
| "Softmax" => new Softmax(args as SoftmaxArgs), | |||
| _ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + | |||
| $"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") | |||
| }; | |||
| } | |||
| public static LayerArgs? deserialize_layer_args(string class_name, JToken config) | |||
| { | |||
| return class_name switch | |||
| { | |||
| "Sequential" => config.ToObject<SequentialArgs>(), | |||
| "InputLayer" => config.ToObject<InputLayerArgs>(), | |||
| "Flatten" => config.ToObject<FlattenArgs>(), | |||
| "ELU" => config.ToObject<ELUArgs>(), | |||
| "Dense" => config.ToObject<DenseArgs>(), | |||
| "Softmax" => config.ToObject<SoftmaxArgs>(), | |||
| _ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " + | |||
| $"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues") | |||
| }; | |||
| } | |||
| public static ModelConfig deserialize_model_config(JToken json) | |||
| { | |||
| ModelConfig config = new ModelConfig(); | |||
| config.Name = json["name"].ToObject<string>(); | |||
| config.Layers = new List<LayerConfig>(); | |||
| var layersToken = json["layers"]; | |||
| foreach (var token in layersToken) | |||
| { | |||
| var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]); | |||
| config.Layers.Add(new LayerConfig() | |||
| { | |||
| Config = args, | |||
| Name = token["name"].ToObject<string>(), | |||
| ClassName = token["class_name"].ToObject<string>(), | |||
| InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>() | |||
| }); | |||
| } | |||
| config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>(); | |||
| config.OutputLayers = json["output_layers"].ToObject<List<NodeConfig>>(); | |||
| return config; | |||
| } | |||
| public static string to_snake_case(string name) | |||
| { | |||
| return string.Concat(name.Select((x, i) => | |||
| @@ -21,17 +21,16 @@ public class SequentialModelLoad | |||
| [TestMethod] | |||
| public void SimpleModelFromSequential() | |||
| { | |||
| var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/model.pb"); | |||
| Debug.Assert(model is Model); | |||
| var m = model as Model; | |||
| new SequentialModelSave().SimpleModelFromSequential(); | |||
| var model = keras.models.load_model(@"./pb_simple_sequential"); | |||
| m.summary(); | |||
| model.summary(); | |||
| m.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||
| model.compile(new Adam(0.0001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); | |||
| var data_loader = new MnistModelLoader(); | |||
| var num_epochs = 1; | |||
| var batch_size = 50; | |||
| var batch_size = 8; | |||
| var dataset = data_loader.LoadAsync(new ModelLoadSetting | |||
| { | |||
| @@ -40,6 +39,6 @@ public class SequentialModelLoad | |||
| ValidationSize = 50000, | |||
| }).Result; | |||
| m.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | |||
| } | |||
| } | |||
| @@ -1,27 +1,21 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow.NumPy; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using System.Diagnostics; | |||
| using Tensorflow; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| using Tensorflow.Keras; | |||
| using Tensorflow.Keras.ArgsDefinition; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Layers; | |||
| using Tensorflow.Keras.Losses; | |||
| using Tensorflow.Keras.Metrics; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using Tensorflow.Operations; | |||
| using System.Diagnostics; | |||
| using Tensorflow.NumPy; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| namespace TensorFlowNET.Keras.UnitTest.SaveModel; | |||
| [TestClass] | |||
| public class SequentialModelTest | |||
| public class SequentialModelSave | |||
| { | |||
| [TestMethod] | |||
| public void SimpleModelFromAutoCompile() | |||
| @@ -118,7 +112,7 @@ public class SequentialModelTest | |||
| keras.layers.Softmax(1) | |||
| }); | |||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); | |||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" }); | |||
| var num_epochs = 1; | |||
| var batch_size = 8; | |||