Browse Source

Add support for loading models from python.

pull/989/head
Yaohui Liu 2 years ago
parent
commit
3972989114
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
23 changed files with 273 additions and 74 deletions
  1. +6
    -3
      src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
  2. +9
    -2
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  3. +13
    -2
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  4. +39
    -0
      src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs
  5. +32
    -5
      src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs
  6. +3
    -0
      src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
  7. +5
    -1
      src/TensorFlowNET.Core/Tensors/TF_DataType.cs
  8. +3
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  9. +22
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs
  10. +36
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
  11. +25
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  12. +3
    -8
      src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs
  13. +1
    -1
      src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs
  14. +0
    -1
      src/TensorFlowNET.Keras/Engine/Layer.cs
  15. +0
    -5
      src/TensorFlowNET.Keras/Layers/Core/Dense.cs
  16. +0
    -5
      src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
  17. +3
    -13
      src/TensorFlowNET.Keras/Models/ModelsApi.cs
  18. +4
    -4
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  19. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
  20. +3
    -3
      src/TensorFlowNET.Keras/Saving/SavedModel/load.cs
  21. +53
    -1
      src/TensorFlowNET.Keras/Utils/generic_utils.cs
  22. +6
    -7
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
  23. +6
    -12
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs

+ 6
- 3
src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs View File

@@ -62,14 +62,17 @@ namespace Tensorflow.Checkpoint
return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name);
}

public unsafe Tensor GetTensor(string name)
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
{
Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle);
status.Check(true);
var shape = GetVariableShape(name);
var dtype = GetVariableDataType(name);
return new Tensor(c_api.TF_TensorData(tensor), shape, dtype);
if(dtype == TF_DataType.DtInvalid)
{
dtype = GetVariableDataType(name);
}
return new Tensor(tensor);
}

private void ReadAllShapeAndType()


+ 9
- 2
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -227,7 +227,7 @@ public class TrackableSaver
{
dtype_map = reader.VariableToDataTypeMap;
}
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY);
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING);

Dictionary<Tensor, string> file_prefix_feed_dict;
Tensor file_prefix_tensor;
@@ -249,7 +249,14 @@ public class TrackableSaver
file_prefix_feed_dict = null;
}
TrackableObjectGraph object_graph_proto = new();
object_graph_proto.MergeFrom(object_graph_string.BufferToArray());
if(object_graph_string.ndim > 0)
{
object_graph_proto.MergeFrom(object_graph_string.BufferToArray());
}
else
{
object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]);
}
CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator(
object_graph_proto: object_graph_proto,
save_path: save_path,


+ 13
- 2
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -13,8 +13,8 @@ namespace Tensorflow.Functions
/// </summary>
public class ConcreteFunction: Trackable
{
FuncGraph func_graph;
ForwardBackwardCall forward_backward;
internal FuncGraph func_graph;
internal ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures;

@@ -23,6 +23,8 @@ namespace Tensorflow.Functions
public Tensor[] Outputs;
public Type ReturnType;
public TensorSpec[] OutputStructure;
public IEnumerable<string> ArgKeywords { get; set; }
public long NumPositionArgs { get; set; }

public ConcreteFunction(string name)
{
@@ -163,6 +165,15 @@ namespace Tensorflow.Functions
return flat_outputs;
}

public void AddTograph(Graph? g = null)
{
if(!tf.Context.executing_eagerly() && g is null)
{
g = ops.get_default_graph();
}
// TODO(Rinne); complete it with `_delayed_rewrite_functions`.
}

ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
{
var functions = new FirstOrderTapeGradientFunctions(func_graph, false);


+ 39
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs View File

@@ -0,0 +1,39 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
public class CustomizedDTypeJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(TF_DataType);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
var token = JToken.FromObject(value);
token.WriteTo(writer);
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
if (reader.ValueType == typeof(string))
{
var str = (string)serializer.Deserialize(reader, typeof(string));
return dtypes.tf_dtype_from_name(str);
}
else
{
return (TF_DataType)serializer.Deserialize(reader, typeof(TF_DataType));
}
}
}
}

+ 32
- 5
src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs View File

@@ -46,7 +46,16 @@ namespace Tensorflow.Keras.Common
{
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
}
if(values.Length != 3)
if(values.Length == 1)
{
var array = values[0] as JArray;
if(array is null)
{
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
}
values = array.ToObject<object[]>();
}
if (values.Length < 3)
{
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
}
@@ -54,19 +63,37 @@ namespace Tensorflow.Keras.Common
{
throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`");
}
if (values[1] is not int)
int nodeIndex;
int tensorIndex;
if (values[1] is long)
{
nodeIndex = (int)(long)values[1];
}
else if (values[1] is int)
{
nodeIndex = (int)values[1];
}
else
{
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`");
}
if (values[2] is not int)
if (values[2] is long)
{
tensorIndex = (int)(long)values[2];
}
else if (values[1] is int)
{
tensorIndex = (int)values[2];
}
else
{
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`");
}
return new NodeConfig()
{
Name = values[0] as string,
NodeIndex = (int)values[1],
TensorIndex = (int)values[2]
NodeIndex = nodeIndex,
TensorIndex = tensorIndex
};
}
}


+ 3
- 0
src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs View File

@@ -1,8 +1,11 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;

namespace Tensorflow.Keras.Saving
{


+ 5
- 1
src/TensorFlowNET.Core/Tensors/TF_DataType.cs View File

@@ -1,9 +1,13 @@
namespace Tensorflow
using Newtonsoft.Json;
using Tensorflow.Keras.Common;

namespace Tensorflow
{
/// <summary>
/// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
/// The enum values here are identical to corresponding values in types.proto.
/// </summary>
[JsonConverter(typeof(CustomizedDTypeJsonConverter))]
public enum TF_DataType
{
DtInvalid = 0,


+ 3
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -159,7 +159,10 @@ namespace Tensorflow
"uint32" => TF_DataType.TF_UINT32,
"int64" => TF_DataType.TF_INT64,
"uint64" => TF_DataType.TF_UINT64,
"float16" => TF_DataType.TF_BFLOAT16,
"float32" => TF_DataType.TF_FLOAT,
"single" => TF_DataType.TF_FLOAT,
"float64" => TF_DataType.TF_DOUBLE,
"double" => TF_DataType.TF_DOUBLE,
"complex" => TF_DataType.TF_COMPLEX128,
"string" => TF_DataType.TF_STRING,


+ 22
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Functions;

namespace Tensorflow.Training.Saving.SavedModel
{
/// <summary>
/// A class wraps a concrete function to handle different distributed contexts.
/// </summary>
internal class WrapperFunction: ConcreteFunction
{
public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph)
{
this.forward_backward = concrete_function.forward_backward;
this.Outputs = concrete_function.Outputs;
this.ReturnType = concrete_function.ReturnType;
this.OutputStructure = concrete_function.OutputStructure;
this.ArgKeywords = concrete_function.ArgKeywords;
}
}
}

+ 36
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs View File

@@ -0,0 +1,36 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Functions;
using Tensorflow.Util;

namespace Tensorflow.Training.Saving.SavedModel
{
public static class function_deserialization
{
public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function,
IDictionary<string, ConcreteFunction> concrete_functions)
{
var concrete_function = concrete_functions[saved_bare_concrete_function.ConcreteFunctionName];
concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList();
concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments;

var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec);
concrete_function.AddTograph();
return concrete_function;
}

private static FunctionSpec _deserialize_function_spec_as_nonmethod(FunctionSpec function_spec_proto)
{
// TODO(Rinne); revise the implementation.
return new FunctionSpec()
{
Fullargspec = function_spec_proto.Fullargspec,
IsMethod = function_spec_proto.IsMethod,
InputSignature = function_spec_proto.InputSignature,
JitCompile = function_spec_proto.JitCompile
};
}
}
}

+ 25
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -12,6 +12,7 @@ using static Tensorflow.Binding;
using System.Runtime.CompilerServices;
using Tensorflow.Variables;
using Tensorflow.Functions;
using Tensorflow.Training.Saving.SavedModel;

namespace Tensorflow
{
@@ -307,6 +308,11 @@ namespace Tensorflow
foreach(var (node_id, proto) in _iter_all_nodes())
{
var node = get(node_id);
if(node is null)
{
// skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
continue;
}
if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
{
// Restore Trackable serialize- and restore-from-tensor functions.
@@ -376,6 +382,13 @@ namespace Tensorflow
}
else
{
// skip the function and concrete function.
if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function)
{
nodes[node_id] = null;
node_setters[node_id] = null;
continue;
}
var (node, setter) = _recreate(proto, node_id, nodes);
nodes[node_id] = node;
node_setters[node_id] = setter;
@@ -480,6 +493,11 @@ namespace Tensorflow

foreach(var refer in proto.Children)
{
if(obj is null)
{
// skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
continue;
}
setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]);
// skip the process of "__call__"
}
@@ -591,6 +609,13 @@ namespace Tensorflow
}
}

private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
Dictionary<Maybe<string, int>, Trackable> dependencies)
{
throw new NotImplementedException();
//var fn = function_deserialization.setup_bare_concrete_function(proto, )
}

// TODO: remove this to a common class.
public static Action<object, object, object> setattr = (x, y, z) =>
{


+ 3
- 8
src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs View File

@@ -24,10 +24,10 @@ namespace Tensorflow.Keras.Engine
/// </summary>
/// <param name="config"></param>
/// <returns></returns>
public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config)
public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config, Dictionary<string, ILayer>? created_layers = null)
{
// Layer instances created during the graph reconstruction process.
var created_layers = new Dictionary<string, ILayer>();
created_layers = created_layers ?? new Dictionary<string, ILayer>();
var node_index_map = new Dictionary<(string, int), int>();
var node_count_by_layer = new Dictionary<ILayer, int>();
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();
@@ -88,12 +88,7 @@ namespace Tensorflow.Keras.Engine
layer = created_layers[layer_name];
else
{
layer = layer_data.ClassName switch
{
"InputLayer" => InputLayer.from_config(layer_data.Config),
"Dense" => Dense.from_config(layer_data.Config),
_ => throw new NotImplementedException("")
};
layer = generic_utils.deserialize_keras_object(layer_data.ClassName, layer_data.Config);

created_layers[layer_name] = layer;
}


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs View File

@@ -12,7 +12,7 @@ public abstract partial class Layer

public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier;

public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata;
public string GetTrackingMetadata() => TrackableSavedModelSaver.TrackingMetadata;

public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{


+ 0
- 1
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -96,7 +96,6 @@ namespace Tensorflow.Keras.Engine

List<INode> inboundNodes;
public List<INode> InboundNodes => inboundNodes;

List<INode> outboundNodes;
public List<INode> OutboundNodes => outboundNodes;



+ 0
- 5
src/TensorFlowNET.Keras/Layers/Core/Dense.cs View File

@@ -85,10 +85,5 @@ namespace Tensorflow.Keras.Layers

return outputs;
}

public static Dense from_config(LayerArgs args)
{
return new Dense(args as DenseArgs);
}
}
}

+ 0
- 5
src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs View File

@@ -102,11 +102,6 @@ namespace Tensorflow.Keras.Layers
name: Name);
}

public static InputLayer from_config(LayerArgs args)
{
return new InputLayer(args as InputLayerArgs);
}

public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this);
}
}

+ 3
- 13
src/TensorFlowNET.Keras/Models/ModelsApi.cs View File

@@ -4,6 +4,7 @@ using System.IO;
using System.Text;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving.SavedModel;
using ThirdParty.Tensorflow.Python.Keras.Protobuf;

namespace Tensorflow.Keras.Models
@@ -13,20 +14,9 @@ namespace Tensorflow.Keras.Models
public Functional from_config(ModelConfig config)
=> Functional.from_config(config);

public void load_model(string filepath, bool compile = true)
public Model load_model(string filepath, bool compile = true, LoadOptions? options = null)
{
var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb"));
var saved_mode = SavedModel.Parser.ParseFrom(bytes);
var meta_graph_def = saved_mode.MetaGraphs[0];
var object_graph_def = meta_graph_def.ObjectGraphDef;

bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb"));
var metadata = SavedMetadata.Parser.ParseFrom(bytes);

// Recreate layers and metrics using the info stored in the metadata.
var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
keras_loader.load_layers(compile: compile);
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model;
}
}
}

+ 4
- 4
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -164,11 +164,11 @@ namespace Tensorflow.Keras.Saving
{
if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer")
{
layers.Insert(0, InputLayer.from_config(config["layers"][0]["config"].ToObject<InputLayerArgs>()));
layers.Insert(0, new InputLayer(config["layers"][0]["config"].ToObject<InputLayerArgs>()));
}
else if (config["layers"][0]["config"]["batch_input_shape"] is not null)
{
// TODO: implement it
// TODO(Rinne): implement it
}
}
@@ -192,7 +192,8 @@ namespace Tensorflow.Keras.Saving
else
{
// skip the parameter `created_layers`.
var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(config.ToObject<ModelConfig>());
var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(generic_utils.deserialize_model_config(config),
layers.ToDictionary(x => x.Name, x => x as ILayer));
// skip the `model.__init__`
(model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>());
(model as Functional).connect_ancillary_layers(created_layers);
@@ -283,7 +284,6 @@ namespace Tensorflow.Keras.Saving

private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json)
{
metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1");
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);

if (loaded_nodes.ContainsKey(node_id))


+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -95,7 +95,7 @@ public partial class KerasSavedModelUtils
BadConsumers = { }
},
Identifier = layer.ObjectIdentifier,
Metadata = layer.TrackingMetadata
Metadata = layer.GetTrackingMetadata()
};

metadata.Nodes.Add(saved_object);


+ 3
- 3
src/TensorFlowNET.Keras/Saving/SavedModel/load.cs View File

@@ -44,7 +44,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
}
}

public static Trackable load(string path, bool compile = true, LoadOptions? options = null)
private static Trackable load(string path, bool compile = true, LoadOptions? options = null)
{
SavedMetadata metadata = new SavedMetadata();
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0];
@@ -82,12 +82,12 @@ namespace Tensorflow.Keras.Saving.SavedModel

if(model is Model && compile)
{
// TODO: implement it.
// TODO(Rinne): implement it.
}

if (!tf.Context.executing_eagerly())
{
// TODO: implement it.
// TODO(Rinne): implement it.
}

return model;


+ 53
- 1
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Utils
return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance);
}

public static Layer deserialize_keras_object(string class_name, JObject config)
public static Layer deserialize_keras_object(string class_name, JToken config)
{
return class_name switch
{
@@ -70,6 +70,58 @@ namespace Tensorflow.Keras.Utils
};
}

public static Layer deserialize_keras_object(string class_name, LayerArgs args)
{
return class_name switch
{
"Sequential" => new Sequential(args as SequentialArgs),
"InputLayer" => new InputLayer(args as InputLayerArgs),
"Flatten" => new Flatten(args as FlattenArgs),
"ELU" => new ELU(args as ELUArgs),
"Dense" => new Dense(args as DenseArgs),
"Softmax" => new Softmax(args as SoftmaxArgs),
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " +
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
};
}

public static LayerArgs? deserialize_layer_args(string class_name, JToken config)
{
return class_name switch
{
"Sequential" => config.ToObject<SequentialArgs>(),
"InputLayer" => config.ToObject<InputLayerArgs>(),
"Flatten" => config.ToObject<FlattenArgs>(),
"ELU" => config.ToObject<ELUArgs>(),
"Dense" => config.ToObject<DenseArgs>(),
"Softmax" => config.ToObject<SoftmaxArgs>(),
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " +
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
};
}

public static ModelConfig deserialize_model_config(JToken json)
{
ModelConfig config = new ModelConfig();
config.Name = json["name"].ToObject<string>();
config.Layers = new List<LayerConfig>();
var layersToken = json["layers"];
foreach (var token in layersToken)
{
var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]);
config.Layers.Add(new LayerConfig()
{
Config = args,
Name = token["name"].ToObject<string>(),
ClassName = token["class_name"].ToObject<string>(),
InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>()
});
}
config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>();
config.OutputLayers = json["output_layers"].ToObject<List<NodeConfig>>();
return config;
}

public static string to_snake_case(string name)
{
return string.Concat(name.Select((x, i) =>


+ 6
- 7
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -21,17 +21,16 @@ public class SequentialModelLoad
[TestMethod]
public void SimpleModelFromSequential()
{
var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/model.pb");
Debug.Assert(model is Model);
var m = model as Model;
new SequentialModelSave().SimpleModelFromSequential();
var model = keras.models.load_model(@"./pb_simple_sequential");

m.summary();
model.summary();

m.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });
model.compile(new Adam(0.0001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

var data_loader = new MnistModelLoader();
var num_epochs = 1;
var batch_size = 50;
var batch_size = 8;

var dataset = data_loader.LoadAsync(new ModelLoadSetting
{
@@ -40,6 +39,6 @@ public class SequentialModelLoad
ValidationSize = 50000,
}).Result;

m.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
}
}

+ 6
- 12
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs View File

@@ -1,27 +1,21 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Diagnostics;
using Tensorflow;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Operations;
using System.Diagnostics;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace TensorFlowNET.Keras.UnitTest.SaveModel;

[TestClass]
public class SequentialModelTest
public class SequentialModelSave
{
[TestMethod]
public void SimpleModelFromAutoCompile()
@@ -118,7 +112,7 @@ public class SequentialModelTest
keras.layers.Softmax(1)
});

model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" });
model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });

var num_epochs = 1;
var batch_size = 8;


Loading…
Cancel
Save