Browse Source

Add support for loading models from python.

pull/989/head
Yaohui Liu 2 years ago
parent
commit
3972989114
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
23 changed files with 273 additions and 74 deletions
  1. +6
    -3
      src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
  2. +9
    -2
      src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
  3. +13
    -2
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  4. +39
    -0
      src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs
  5. +32
    -5
      src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs
  6. +3
    -0
      src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
  7. +5
    -1
      src/TensorFlowNET.Core/Tensors/TF_DataType.cs
  8. +3
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  9. +22
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs
  10. +36
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
  11. +25
    -0
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  12. +3
    -8
      src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs
  13. +1
    -1
      src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs
  14. +0
    -1
      src/TensorFlowNET.Keras/Engine/Layer.cs
  15. +0
    -5
      src/TensorFlowNET.Keras/Layers/Core/Dense.cs
  16. +0
    -5
      src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs
  17. +3
    -13
      src/TensorFlowNET.Keras/Models/ModelsApi.cs
  18. +4
    -4
      src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs
  19. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
  20. +3
    -3
      src/TensorFlowNET.Keras/Saving/SavedModel/load.cs
  21. +53
    -1
      src/TensorFlowNET.Keras/Utils/generic_utils.cs
  22. +6
    -7
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs
  23. +6
    -12
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs

+ 6
- 3
src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs View File

@@ -62,14 +62,17 @@ namespace Tensorflow.Checkpoint
return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name); return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name);
} }


public unsafe Tensor GetTensor(string name)
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
{ {
Status status = new Status(); Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle); var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle);
status.Check(true); status.Check(true);
var shape = GetVariableShape(name); var shape = GetVariableShape(name);
var dtype = GetVariableDataType(name);
return new Tensor(c_api.TF_TensorData(tensor), shape, dtype);
if(dtype == TF_DataType.DtInvalid)
{
dtype = GetVariableDataType(name);
}
return new Tensor(tensor);
} }


private void ReadAllShapeAndType() private void ReadAllShapeAndType()


+ 9
- 2
src/TensorFlowNET.Core/Checkpoint/checkpoint.cs View File

@@ -227,7 +227,7 @@ public class TrackableSaver
{ {
dtype_map = reader.VariableToDataTypeMap; dtype_map = reader.VariableToDataTypeMap;
} }
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY);
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING);


Dictionary<Tensor, string> file_prefix_feed_dict; Dictionary<Tensor, string> file_prefix_feed_dict;
Tensor file_prefix_tensor; Tensor file_prefix_tensor;
@@ -249,7 +249,14 @@ public class TrackableSaver
file_prefix_feed_dict = null; file_prefix_feed_dict = null;
} }
TrackableObjectGraph object_graph_proto = new(); TrackableObjectGraph object_graph_proto = new();
object_graph_proto.MergeFrom(object_graph_string.BufferToArray());
if(object_graph_string.ndim > 0)
{
object_graph_proto.MergeFrom(object_graph_string.BufferToArray());
}
else
{
object_graph_proto.MergeFrom(object_graph_string.StringBytes()[0]);
}
CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator( CheckpointRestoreCoordinator checkpoint = new CheckpointRestoreCoordinator(
object_graph_proto: object_graph_proto, object_graph_proto: object_graph_proto,
save_path: save_path, save_path: save_path,


+ 13
- 2
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -13,8 +13,8 @@ namespace Tensorflow.Functions
/// </summary> /// </summary>
public class ConcreteFunction: Trackable public class ConcreteFunction: Trackable
{ {
FuncGraph func_graph;
ForwardBackwardCall forward_backward;
internal FuncGraph func_graph;
internal ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs; public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures; public Tensor[] CapturedInputs => func_graph.external_captures;


@@ -23,6 +23,8 @@ namespace Tensorflow.Functions
public Tensor[] Outputs; public Tensor[] Outputs;
public Type ReturnType; public Type ReturnType;
public TensorSpec[] OutputStructure; public TensorSpec[] OutputStructure;
public IEnumerable<string> ArgKeywords { get; set; }
public long NumPositionArgs { get; set; }


public ConcreteFunction(string name) public ConcreteFunction(string name)
{ {
@@ -163,6 +165,15 @@ namespace Tensorflow.Functions
return flat_outputs; return flat_outputs;
} }


public void AddTograph(Graph? g = null)
{
if(!tf.Context.executing_eagerly() && g is null)
{
g = ops.get_default_graph();
}
// TODO(Rinne); complete it with `_delayed_rewrite_functions`.
}

ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly) ForwardBackwardCall SelectForwardAndBackwardFunctions(Tensors args, int possible_gradient_type, bool executing_eagerly)
{ {
var functions = new FirstOrderTapeGradientFunctions(func_graph, false); var functions = new FirstOrderTapeGradientFunctions(func_graph, false);


+ 39
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs View File

@@ -0,0 +1,39 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
public class CustomizedDTypeJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(TF_DataType);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
var token = JToken.FromObject(value);
token.WriteTo(writer);
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
if (reader.ValueType == typeof(string))
{
var str = (string)serializer.Deserialize(reader, typeof(string));
return dtypes.tf_dtype_from_name(str);
}
else
{
return (TF_DataType)serializer.Deserialize(reader, typeof(TF_DataType));
}
}
}
}

+ 32
- 5
src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs View File

@@ -46,7 +46,16 @@ namespace Tensorflow.Keras.Common
{ {
throw new ValueError("Cannot deserialize 'null' to `Shape`."); throw new ValueError("Cannot deserialize 'null' to `Shape`.");
} }
if(values.Length != 3)
if(values.Length == 1)
{
var array = values[0] as JArray;
if(array is null)
{
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
}
values = array.ToObject<object[]>();
}
if (values.Length < 3)
{ {
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
} }
@@ -54,19 +63,37 @@ namespace Tensorflow.Keras.Common
{ {
throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`");
} }
if (values[1] is not int)
int nodeIndex;
int tensorIndex;
if (values[1] is long)
{
nodeIndex = (int)(long)values[1];
}
else if (values[1] is int)
{
nodeIndex = (int)values[1];
}
else
{ {
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`");
} }
if (values[2] is not int)
if (values[2] is long)
{
tensorIndex = (int)(long)values[2];
}
else if (values[1] is int)
{
tensorIndex = (int)values[2];
}
else
{ {
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`");
} }
return new NodeConfig() return new NodeConfig()
{ {
Name = values[0] as string, Name = values[0] as string,
NodeIndex = (int)values[1],
TensorIndex = (int)values[2]
NodeIndex = nodeIndex,
TensorIndex = tensorIndex
}; };
} }
} }


+ 3
- 0
src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs View File

@@ -1,8 +1,11 @@
using Newtonsoft.Json; using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;


namespace Tensorflow.Keras.Saving namespace Tensorflow.Keras.Saving
{ {


+ 5
- 1
src/TensorFlowNET.Core/Tensors/TF_DataType.cs View File

@@ -1,9 +1,13 @@
namespace Tensorflow
using Newtonsoft.Json;
using Tensorflow.Keras.Common;

namespace Tensorflow
{ {
/// <summary> /// <summary>
/// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor. /// TF_DataType holds the type for a scalar value. E.g., one slot in a tensor.
/// The enum values here are identical to corresponding values in types.proto. /// The enum values here are identical to corresponding values in types.proto.
/// </summary> /// </summary>
[JsonConverter(typeof(CustomizedDTypeJsonConverter))]
public enum TF_DataType public enum TF_DataType
{ {
DtInvalid = 0, DtInvalid = 0,


+ 3
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -159,7 +159,10 @@ namespace Tensorflow
"uint32" => TF_DataType.TF_UINT32, "uint32" => TF_DataType.TF_UINT32,
"int64" => TF_DataType.TF_INT64, "int64" => TF_DataType.TF_INT64,
"uint64" => TF_DataType.TF_UINT64, "uint64" => TF_DataType.TF_UINT64,
"float16" => TF_DataType.TF_BFLOAT16,
"float32" => TF_DataType.TF_FLOAT,
"single" => TF_DataType.TF_FLOAT, "single" => TF_DataType.TF_FLOAT,
"float64" => TF_DataType.TF_DOUBLE,
"double" => TF_DataType.TF_DOUBLE, "double" => TF_DataType.TF_DOUBLE,
"complex" => TF_DataType.TF_COMPLEX128, "complex" => TF_DataType.TF_COMPLEX128,
"string" => TF_DataType.TF_STRING, "string" => TF_DataType.TF_STRING,


+ 22
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/WrapperFunction.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Functions;

namespace Tensorflow.Training.Saving.SavedModel
{
/// <summary>
/// A class wraps a concrete function to handle different distributed contexts.
/// </summary>
internal class WrapperFunction: ConcreteFunction
{
public WrapperFunction(ConcreteFunction concrete_function): base(concrete_function.func_graph)
{
this.forward_backward = concrete_function.forward_backward;
this.Outputs = concrete_function.Outputs;
this.ReturnType = concrete_function.ReturnType;
this.OutputStructure = concrete_function.OutputStructure;
this.ArgKeywords = concrete_function.ArgKeywords;
}
}
}

+ 36
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs View File

@@ -0,0 +1,36 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Functions;
using Tensorflow.Util;

namespace Tensorflow.Training.Saving.SavedModel
{
public static class function_deserialization
{
public static ConcreteFunction setup_bare_concrete_function(SavedBareConcreteFunction saved_bare_concrete_function,
IDictionary<string, ConcreteFunction> concrete_functions)
{
var concrete_function = concrete_functions[saved_bare_concrete_function.ConcreteFunctionName];
concrete_function.ArgKeywords = saved_bare_concrete_function.ArgumentKeywords.ToList();
concrete_function.NumPositionArgs = saved_bare_concrete_function.AllowedPositionalArguments;

var function_spec = _deserialize_function_spec_as_nonmethod(saved_bare_concrete_function.FunctionSpec);
concrete_function.AddTograph();
return concrete_function;
}

private static FunctionSpec _deserialize_function_spec_as_nonmethod(FunctionSpec function_spec_proto)
{
// TODO(Rinne); revise the implementation.
return new FunctionSpec()
{
Fullargspec = function_spec_proto.Fullargspec,
IsMethod = function_spec_proto.IsMethod,
InputSignature = function_spec_proto.InputSignature,
JitCompile = function_spec_proto.JitCompile
};
}
}
}

+ 25
- 0
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -12,6 +12,7 @@ using static Tensorflow.Binding;
using System.Runtime.CompilerServices; using System.Runtime.CompilerServices;
using Tensorflow.Variables; using Tensorflow.Variables;
using Tensorflow.Functions; using Tensorflow.Functions;
using Tensorflow.Training.Saving.SavedModel;


namespace Tensorflow namespace Tensorflow
{ {
@@ -307,6 +308,11 @@ namespace Tensorflow
foreach(var (node_id, proto) in _iter_all_nodes()) foreach(var (node_id, proto) in _iter_all_nodes())
{ {
var node = get(node_id); var node = get(node_id);
if(node is null)
{
// skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
continue;
}
if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME) if(proto.SaveableObjects.Keys.Count == 1 && proto.SaveableObjects.First().Key == TrackableUtils.SERIALIZE_TO_TENSORS_NAME)
{ {
// Restore Trackable serialize- and restore-from-tensor functions. // Restore Trackable serialize- and restore-from-tensor functions.
@@ -376,6 +382,13 @@ namespace Tensorflow
} }
else else
{ {
// skip the function and concrete function.
if(proto.KindCase == SavedObject.KindOneofCase.BareConcreteFunction || proto.KindCase == SavedObject.KindOneofCase.Function)
{
nodes[node_id] = null;
node_setters[node_id] = null;
continue;
}
var (node, setter) = _recreate(proto, node_id, nodes); var (node, setter) = _recreate(proto, node_id, nodes);
nodes[node_id] = node; nodes[node_id] = node;
node_setters[node_id] = setter; node_setters[node_id] = setter;
@@ -480,6 +493,11 @@ namespace Tensorflow


foreach(var refer in proto.Children) foreach(var refer in proto.Children)
{ {
if(obj is null)
{
// skip it because now we skip the restoration of `Function` and `ConcreteFunction`.
continue;
}
setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]); setter.Invoke(obj, refer.LocalName, _nodes[refer.NodeId]);
// skip the process of "__call__" // skip the process of "__call__"
} }
@@ -591,6 +609,13 @@ namespace Tensorflow
} }
} }


private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
Dictionary<Maybe<string, int>, Trackable> dependencies)
{
throw new NotImplementedException();
//var fn = function_deserialization.setup_bare_concrete_function(proto, )
}

// TODO: remove this to a common class. // TODO: remove this to a common class.
public static Action<object, object, object> setattr = (x, y, z) => public static Action<object, object, object> setattr = (x, y, z) =>
{ {


+ 3
- 8
src/TensorFlowNET.Keras/Engine/Functional.FromConfig.cs View File

@@ -24,10 +24,10 @@ namespace Tensorflow.Keras.Engine
/// </summary> /// </summary>
/// <param name="config"></param> /// <param name="config"></param>
/// <returns></returns> /// <returns></returns>
public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config)
public static (Tensors, Tensors, Dictionary<string, ILayer>) reconstruct_from_config(ModelConfig config, Dictionary<string, ILayer>? created_layers = null)
{ {
// Layer instances created during the graph reconstruction process. // Layer instances created during the graph reconstruction process.
var created_layers = new Dictionary<string, ILayer>();
created_layers = created_layers ?? new Dictionary<string, ILayer>();
var node_index_map = new Dictionary<(string, int), int>(); var node_index_map = new Dictionary<(string, int), int>();
var node_count_by_layer = new Dictionary<ILayer, int>(); var node_count_by_layer = new Dictionary<ILayer, int>();
var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>(); var unprocessed_nodes = new Dictionary<ILayer, NodeConfig>();
@@ -88,12 +88,7 @@ namespace Tensorflow.Keras.Engine
layer = created_layers[layer_name]; layer = created_layers[layer_name];
else else
{ {
layer = layer_data.ClassName switch
{
"InputLayer" => InputLayer.from_config(layer_data.Config),
"Dense" => Dense.from_config(layer_data.Config),
_ => throw new NotImplementedException("")
};
layer = generic_utils.deserialize_keras_object(layer_data.ClassName, layer_data.Config);


created_layers[layer_name] = layer; created_layers[layer_name] = layer;
} }


+ 1
- 1
src/TensorFlowNET.Keras/Engine/Layer.Serialize.cs View File

@@ -12,7 +12,7 @@ public abstract partial class Layer


public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier; public override string ObjectIdentifier => TrackableSavedModelSaver.ObjectIdentifier;


public string TrackingMetadata => TrackableSavedModelSaver.TrackingMetadata;
public string GetTrackingMetadata() => TrackableSavedModelSaver.TrackingMetadata;


public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null) public override IDictionary<string, Trackable> _trackable_children(SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
{ {


+ 0
- 1
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -96,7 +96,6 @@ namespace Tensorflow.Keras.Engine


List<INode> inboundNodes; List<INode> inboundNodes;
public List<INode> InboundNodes => inboundNodes; public List<INode> InboundNodes => inboundNodes;

List<INode> outboundNodes; List<INode> outboundNodes;
public List<INode> OutboundNodes => outboundNodes; public List<INode> OutboundNodes => outboundNodes;




+ 0
- 5
src/TensorFlowNET.Keras/Layers/Core/Dense.cs View File

@@ -85,10 +85,5 @@ namespace Tensorflow.Keras.Layers


return outputs; return outputs;
} }

public static Dense from_config(LayerArgs args)
{
return new Dense(args as DenseArgs);
}
} }
} }

+ 0
- 5
src/TensorFlowNET.Keras/Layers/Core/InputLayer.cs View File

@@ -102,11 +102,6 @@ namespace Tensorflow.Keras.Layers
name: Name); name: Name);
} }


public static InputLayer from_config(LayerArgs args)
{
return new InputLayer(args as InputLayerArgs);
}

public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this); public override SavedModelSaver TrackableSavedModelSaver => new InputLayerSavedModelSaver(this);
} }
} }

+ 3
- 13
src/TensorFlowNET.Keras/Models/ModelsApi.cs View File

@@ -4,6 +4,7 @@ using System.IO;
using System.Text; using System.Text;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving; using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving.SavedModel;
using ThirdParty.Tensorflow.Python.Keras.Protobuf; using ThirdParty.Tensorflow.Python.Keras.Protobuf;


namespace Tensorflow.Keras.Models namespace Tensorflow.Keras.Models
@@ -13,20 +14,9 @@ namespace Tensorflow.Keras.Models
public Functional from_config(ModelConfig config) public Functional from_config(ModelConfig config)
=> Functional.from_config(config); => Functional.from_config(config);


public void load_model(string filepath, bool compile = true)
public Model load_model(string filepath, bool compile = true, LoadOptions? options = null)
{ {
var bytes = File.ReadAllBytes(Path.Combine(filepath, "saved_model.pb"));
var saved_mode = SavedModel.Parser.ParseFrom(bytes);
var meta_graph_def = saved_mode.MetaGraphs[0];
var object_graph_def = meta_graph_def.ObjectGraphDef;

bytes = File.ReadAllBytes(Path.Combine(filepath, "keras_metadata.pb"));
var metadata = SavedMetadata.Parser.ParseFrom(bytes);

// Recreate layers and metrics using the info stored in the metadata.
var keras_loader = new KerasObjectLoader(metadata, object_graph_def);
keras_loader.load_layers(compile: compile);
return KerasLoadModelUtils.load_model(filepath, compile: compile, options: options) as Model;
} }
} }
} }

+ 4
- 4
src/TensorFlowNET.Keras/Saving/KerasObjectLoader.cs View File

@@ -164,11 +164,11 @@ namespace Tensorflow.Keras.Saving
{ {
if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer") if (config["layers"][0]["class_name"].ToObject<string>() == "InputLayer")
{ {
layers.Insert(0, InputLayer.from_config(config["layers"][0]["config"].ToObject<InputLayerArgs>()));
layers.Insert(0, new InputLayer(config["layers"][0]["config"].ToObject<InputLayerArgs>()));
} }
else if (config["layers"][0]["config"]["batch_input_shape"] is not null) else if (config["layers"][0]["config"]["batch_input_shape"] is not null)
{ {
// TODO: implement it
// TODO(Rinne): implement it
} }
} }
@@ -192,7 +192,8 @@ namespace Tensorflow.Keras.Saving
else else
{ {
// skip the parameter `created_layers`. // skip the parameter `created_layers`.
var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(config.ToObject<ModelConfig>());
var (inputs, outputs, created_layers) = Functional.reconstruct_from_config(generic_utils.deserialize_model_config(config),
layers.ToDictionary(x => x.Name, x => x as ILayer));
// skip the `model.__init__` // skip the `model.__init__`
(model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>()); (model as Functional).Initialize(inputs, outputs, config["name"].ToObject<string>());
(model as Functional).connect_ancillary_layers(created_layers); (model as Functional).connect_ancillary_layers(created_layers);
@@ -283,7 +284,6 @@ namespace Tensorflow.Keras.Saving


private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json) private (Trackable, Action<object, object, object>) _load_layer(int node_id, string identifier, string metadata_json)
{ {
metadata_json = metadata_json.Replace("\"dtype\": \"float32\"", "\"dtype\": 1");
var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json); var metadata = JsonConvert.DeserializeObject<KerasMetaData>(metadata_json);


if (loaded_nodes.ContainsKey(node_id)) if (loaded_nodes.ContainsKey(node_id))


+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -95,7 +95,7 @@ public partial class KerasSavedModelUtils
BadConsumers = { } BadConsumers = { }
}, },
Identifier = layer.ObjectIdentifier, Identifier = layer.ObjectIdentifier,
Metadata = layer.TrackingMetadata
Metadata = layer.GetTrackingMetadata()
}; };


metadata.Nodes.Add(saved_object); metadata.Nodes.Add(saved_object);


+ 3
- 3
src/TensorFlowNET.Keras/Saving/SavedModel/load.cs View File

@@ -44,7 +44,7 @@ namespace Tensorflow.Keras.Saving.SavedModel
} }
} }


public static Trackable load(string path, bool compile = true, LoadOptions? options = null)
private static Trackable load(string path, bool compile = true, LoadOptions? options = null)
{ {
SavedMetadata metadata = new SavedMetadata(); SavedMetadata metadata = new SavedMetadata();
var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0]; var meta_graph_def = Loader.parse_saved_model(path).MetaGraphs[0];
@@ -82,12 +82,12 @@ namespace Tensorflow.Keras.Saving.SavedModel


if(model is Model && compile) if(model is Model && compile)
{ {
// TODO: implement it.
// TODO(Rinne): implement it.
} }


if (!tf.Context.executing_eagerly()) if (!tf.Context.executing_eagerly())
{ {
// TODO: implement it.
// TODO(Rinne): implement it.
} }


return model; return model;


+ 53
- 1
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -55,7 +55,7 @@ namespace Tensorflow.Keras.Utils
return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance);
} }


public static Layer deserialize_keras_object(string class_name, JObject config)
public static Layer deserialize_keras_object(string class_name, JToken config)
{ {
return class_name switch return class_name switch
{ {
@@ -70,6 +70,58 @@ namespace Tensorflow.Keras.Utils
}; };
} }


public static Layer deserialize_keras_object(string class_name, LayerArgs args)
{
return class_name switch
{
"Sequential" => new Sequential(args as SequentialArgs),
"InputLayer" => new InputLayer(args as InputLayerArgs),
"Flatten" => new Flatten(args as FlattenArgs),
"ELU" => new ELU(args as ELUArgs),
"Dense" => new Dense(args as DenseArgs),
"Softmax" => new Softmax(args as SoftmaxArgs),
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " +
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
};
}

public static LayerArgs? deserialize_layer_args(string class_name, JToken config)
{
return class_name switch
{
"Sequential" => config.ToObject<SequentialArgs>(),
"InputLayer" => config.ToObject<InputLayerArgs>(),
"Flatten" => config.ToObject<FlattenArgs>(),
"ELU" => config.ToObject<ELUArgs>(),
"Dense" => config.ToObject<DenseArgs>(),
"Softmax" => config.ToObject<SoftmaxArgs>(),
_ => throw new NotImplementedException($"The deserialization of <{class_name}> has not been supported. Usually it's a miss during the development. " +
$"Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues")
};
}

public static ModelConfig deserialize_model_config(JToken json)
{
ModelConfig config = new ModelConfig();
config.Name = json["name"].ToObject<string>();
config.Layers = new List<LayerConfig>();
var layersToken = json["layers"];
foreach (var token in layersToken)
{
var args = deserialize_layer_args(token["class_name"].ToObject<string>(), token["config"]);
config.Layers.Add(new LayerConfig()
{
Config = args,
Name = token["name"].ToObject<string>(),
ClassName = token["class_name"].ToObject<string>(),
InboundNodes = token["inbound_nodes"].ToObject<List<NodeConfig>>()
});
}
config.InputLayers = json["input_layers"].ToObject<List<NodeConfig>>();
config.OutputLayers = json["output_layers"].ToObject<List<NodeConfig>>();
return config;
}

public static string to_snake_case(string name) public static string to_snake_case(string name)
{ {
return string.Concat(name.Select((x, i) => return string.Concat(name.Select((x, i) =>


+ 6
- 7
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -21,17 +21,16 @@ public class SequentialModelLoad
[TestMethod] [TestMethod]
public void SimpleModelFromSequential() public void SimpleModelFromSequential()
{ {
var model = KerasLoadModelUtils.load_model(@"D:/development/tf.net/tf_test/model.pb");
Debug.Assert(model is Model);
var m = model as Model;
new SequentialModelSave().SimpleModelFromSequential();
var model = keras.models.load_model(@"./pb_simple_sequential");


m.summary();
model.summary();


m.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });
model.compile(new Adam(0.0001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });


var data_loader = new MnistModelLoader(); var data_loader = new MnistModelLoader();
var num_epochs = 1; var num_epochs = 1;
var batch_size = 50;
var batch_size = 8;


var dataset = data_loader.LoadAsync(new ModelLoadSetting var dataset = data_loader.LoadAsync(new ModelLoadSetting
{ {
@@ -40,6 +39,6 @@ public class SequentialModelLoad
ValidationSize = 50000, ValidationSize = 50000,
}).Result; }).Result;


m.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
} }
} }

+ 6
- 12
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs View File

@@ -1,27 +1,21 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.NumPy;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Diagnostics;
using Tensorflow; using Tensorflow;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Keras; using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers; using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses; using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Optimizers;
using Tensorflow.Operations;
using System.Diagnostics;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;


namespace TensorFlowNET.Keras.UnitTest.SaveModel; namespace TensorFlowNET.Keras.UnitTest.SaveModel;


[TestClass] [TestClass]
public class SequentialModelTest
public class SequentialModelSave
{ {
[TestMethod] [TestMethod]
public void SimpleModelFromAutoCompile() public void SimpleModelFromAutoCompile()
@@ -118,7 +112,7 @@ public class SequentialModelTest
keras.layers.Softmax(1) keras.layers.Softmax(1)
}); });


model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" });
model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });


var num_epochs = 1; var num_epochs = 1;
var batch_size = 8; var batch_size = 8;


Loading…
Cancel
Save