From 2ab0bdbc8690b0048eaeb5ab5069e042b1a88d25 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Fri, 3 Feb 2023 19:08:50 +0800 Subject: [PATCH] Add more implementations to the keras part of pb model save. --- .../ArgsDefinition/Activation/SoftmaxArgs.cs | 17 ++- .../ArgsDefinition/AutoSerializeLayerArgs.cs | 19 +++ .../Keras/ArgsDefinition/Core/DenseArgs.cs | 42 +++++- .../ArgsDefinition/Core/InputLayerArgs.cs | 17 ++- .../Keras/ArgsDefinition/DataAdapterArgs.cs | 3 +- .../Keras/ArgsDefinition/DataHandlerArgs.cs | 3 +- .../Keras/ArgsDefinition/LayerArgs.cs | 31 +++-- .../Keras/ArgsDefinition/NodeArgs.cs | 6 +- .../Keras/ArgsDefinition/OptimizerV2Args.cs | 6 +- .../ArgsDefinition/Reshaping/FlattenArgs.cs | 7 +- .../CustomizedActivationJsonConverter.cs | 50 +++++++ .../Common/CustomizedAxisJsonConverter.cs | 48 +++++++ .../CustomizedNodeConfigJsonConverter.cs | 73 ++++++++++ .../Common/CustomizedShapeJsonConverter.cs | 67 ++++++++++ .../Keras/Engine/InputSpec.cs | 31 ++++- src/TensorFlowNET.Core/Keras/Layers/ILayer.cs | 5 +- .../Keras/Saving/IKerasConfig.cs | 15 +++ .../Keras/Saving/LayerConfig.cs | 9 +- .../Keras/Saving/ModelConfig.cs | 9 +- .../Keras/Saving/NodeConfig.cs | 7 +- .../Keras/Saving/TensorShapeConfig.cs | 21 +++ src/TensorFlowNET.Core/NumPy/Axis.cs | 11 +- src/TensorFlowNET.Core/Numpy/Shape.cs | 3 + .../Operations/Initializers/Constant.cs | 10 ++ .../Operations/Initializers/GlorotUniform.cs | 10 +- .../Operations/Initializers/IInitializer.cs | 7 + .../Operations/Initializers/Ones.cs | 7 + .../Operations/Initializers/Orthogonal.cs | 5 + .../Operations/Initializers/RandomNormal.cs | 12 ++ .../Operations/Initializers/RandomUniform.cs | 12 ++ .../Initializers/TruncatedNormal.cs | 11 ++ .../Initializers/VarianceScaling.cs | 13 ++ .../Operations/Initializers/Zeros.cs | 5 + .../Operations/NnOps/RNNCell.cs | 5 +- .../Tensorflow.Binding.csproj | 1 + src/TensorFlowNET.Core/Tensors/dtypes.cs | 18 +++ .../{ITrackable.cs => IWithTrackable.cs} | 2 +- src/TensorFlowNET.Core/Training/Trackable.cs | 2 +- .../Engine/Functional.GetConfig.cs | 31 +++-- src/TensorFlowNET.Keras/Engine/Functional.cs | 18 +++ src/TensorFlowNET.Keras/Engine/Layer.cs | 6 +- src/TensorFlowNET.Keras/Engine/Model.Save.cs | 6 +- .../Layers/Activation/ELU.cs | 1 + .../Layers/Activation/Exponential.cs | 1 + .../Layers/Activation/SELU.cs | 9 +- .../Layers/Attention/Attention.cs | 3 +- .../Layers/Attention/BaseDenseAttention.cs | 3 +- .../Layers/Convolution/Conv2DTranspose.cs | 1 + .../Layers/Convolution/Convolutional.cs | 1 + src/TensorFlowNET.Keras/Layers/Core/Dense.cs | 1 + .../Layers/Core/Embedding.cs | 1 + .../Layers/Cropping/Cropping1D.cs | 1 + .../Layers/Cropping/Cropping2D.cs | 3 +- .../Layers/Cropping/Cropping3D.cs | 3 +- src/TensorFlowNET.Keras/Layers/LayersApi.cs | 4 +- .../Layers/Merging/Concatenate.cs | 1 + .../Layers/Merging/Merge.cs | 1 + .../Normalization/BatchNormalization.cs | 1 + .../Normalization/LayerNormalization.cs | 1 + .../Layers/Reshaping/Permute.cs | 1 + .../Layers/Rnn/SimpleRNN.cs | 1 + .../Layers/Rnn/StackedRNNCells.cs | 3 +- .../Saving/SavedModel/Save.cs | 2 +- .../Saving/SavedModel/layer_serialization.cs | 33 +++-- .../Saving/TensorShapeConfig.cs | 15 --- .../Saving/serialization.cs | 125 ++++++++++++++++++ .../Utils/base_layer_utils.cs | 2 +- .../Utils/generic_utils.cs | 14 +- .../Layers/ModelSaveTest.cs | 5 +- test/TensorFlowNET.Keras.UnitTest/SaveTest.cs | 40 ++++-- 70 files changed, 849 insertions(+), 109 deletions(-) create mode 100644 src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs create mode 100644 src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs create mode 100644 src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs create mode 100644 src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs create mode 100644 src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs create mode 100644 src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs create mode 100644 src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs rename src/TensorFlowNET.Core/Training/{ITrackable.cs => IWithTrackable.cs} (82%) delete mode 100644 src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs create mode 100644 src/TensorFlowNET.Keras/Saving/serialization.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs index ca35d75d..a37973bc 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs @@ -1,9 +1,18 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.ArgsDefinition { - public class SoftmaxArgs : LayerArgs { - public Axis axis { get; set; } = -1; - } + public class SoftmaxArgs : LayerArgs + { + [JsonProperty("axis")] + public Axis axis { get; set; } = -1; + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs new file mode 100644 index 00000000..66b34a1a --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs @@ -0,0 +1,19 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.ArgsDefinition +{ + public class AutoSerializeLayerArgs: LayerArgs + { + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] + public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } + } +} diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs index e9b3c2fd..8f4facbd 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs @@ -1,13 +1,18 @@ -using System; +using Newtonsoft.Json; +using System; +using System.Xml.Linq; +using Tensorflow.Operations.Initializers; using static Tensorflow.Binding; namespace Tensorflow.Keras.ArgsDefinition { + // TODO: `activity_regularizer` public class DenseArgs : LayerArgs { /// /// Positive integer, dimensionality of the output space. /// + [JsonProperty("units")] public int Units { get; set; } /// @@ -15,39 +20,74 @@ namespace Tensorflow.Keras.ArgsDefinition /// public Activation Activation { get; set; } + private string _activationName; + [JsonProperty("activation")] + public string ActivationName + { + get + { + if (string.IsNullOrEmpty(_activationName)) + { + return Activation.Method.Name; + } + else + { + return _activationName; + } + } + set + { + _activationName = value; + } + } + /// /// Whether the layer uses a bias vector. /// + [JsonProperty("use_bias")] public bool UseBias { get; set; } = true; /// /// Initializer for the `kernel` weights matrix. /// + [JsonProperty("kernel_initializer")] public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; /// /// Initializer for the bias vector. /// + [JsonProperty("bias_initializer")] public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; /// /// Regularizer function applied to the `kernel` weights matrix. /// + [JsonProperty("kernel_regularizer")] public IRegularizer KernelRegularizer { get; set; } /// /// Regularizer function applied to the bias vector. /// + [JsonProperty("bias_regularizer")] public IRegularizer BiasRegularizer { get; set; } /// /// Constraint function applied to the `kernel` weights matrix. /// + [JsonProperty("kernel_constraint")] public Action KernelConstraint { get; set; } /// /// Constraint function applied to the bias vector. /// + [JsonProperty("bias_constraint")] public Action BiasConstraint { get; set; } + + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("trainable")] + public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs index 723109c2..be43e0a6 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs @@ -1,9 +1,22 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; +using Newtonsoft.Json.Serialization; +using Tensorflow.Keras.Common; + +namespace Tensorflow.Keras.ArgsDefinition { public class InputLayerArgs : LayerArgs { + [JsonIgnore] public Tensor InputTensor { get; set; } - public bool Sparse { get; set; } + [JsonProperty("sparse")] + public virtual bool Sparse { get; set; } + [JsonProperty("ragged")] public bool Ragged { get; set; } + [JsonProperty("name")] + public override string Name { get => base.Name; set => base.Name = value; } + [JsonProperty("dtype")] + public override TF_DataType DType { get => base.DType; set => base.DType = value; } + [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] + public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs index f3cca438..8ce1ec65 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs @@ -1,8 +1,9 @@ using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.ArgsDefinition { - public class DataAdapterArgs + public class DataAdapterArgs: IKerasConfig { public Tensor X { get; set; } public Tensor Y { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs index b6e6849b..fd603a85 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs @@ -1,8 +1,9 @@ using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.ArgsDefinition { - public class DataHandlerArgs + public class DataHandlerArgs: IKerasConfig { public Tensor X { get; set; } public Tensor Y { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs index 4df4fb2b..febf1417 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs @@ -1,51 +1,54 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition { - public class LayerArgs + [JsonObject(MemberSerialization.OptIn)] + public class LayerArgs: IKerasConfig { /// /// Indicates whether the layer's weights are updated during training /// and whether the layer's updates are run during training. /// - public bool Trainable { get; set; } = true; - - public string Name { get; set; } + public virtual bool Trainable { get; set; } = true; + public virtual string Name { get; set; } /// /// Only applicable to input layers. /// - public TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; + public virtual TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; /// /// Whether the `call` method can be used to build a TF graph without issues. /// This attribute has no effect if the model is created using the Functional /// API. Instead, `model.dynamic` is determined based on the internal layers. /// - public bool Dynamic { get; set; } = false; + public virtual bool Dynamic { get; set; } = false; /// /// Only applicable to input layers. /// - public Shape InputShape { get; set; } + public virtual Shape InputShape { get; set; } /// /// Only applicable to input layers. /// - public Shape BatchInputShape { get; set; } + public virtual Shape BatchInputShape { get; set; } - public int BatchSize { get; set; } = -1; + public virtual int BatchSize { get; set; } = -1; /// /// Initial weight values. /// - public float[] Weights { get; set; } + public virtual float[] Weights { get; set; } /// /// Regularizer function applied to the output of the layer(its "activation"). /// - public IRegularizer ActivityRegularizer { get; set; } + public virtual IRegularizer ActivityRegularizer { get; set; } - public bool Autocast { get; set; } + public virtual bool Autocast { get; set; } - public bool IsFromConfig { get; set; } + public virtual bool IsFromConfig { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs index 0d9e26ac..ad55ff61 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs @@ -1,6 +1,8 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition { - public class NodeArgs + public class NodeArgs: IKerasConfig { public ILayer[] InboundLayers { get; set; } public int[] NodeIndices { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs index e2a0e43c..6256fd32 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs @@ -1,6 +1,8 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.ArgsDefinition { - public class OptimizerV2Args + public class OptimizerV2Args: IKerasConfig { public string Name { get; set; } public float LearningRate { get; set; } = 0.001f; diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs index c2b48cc2..91ffc205 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs @@ -1,7 +1,10 @@ -namespace Tensorflow.Keras.ArgsDefinition +using Newtonsoft.Json; + +namespace Tensorflow.Keras.ArgsDefinition { - public class FlattenArgs : LayerArgs + public class FlattenArgs : AutoSerializeLayerArgs { + [JsonProperty("data_format")] public string DataFormat { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs new file mode 100644 index 00000000..1bc13caf --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs @@ -0,0 +1,50 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedActivationJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Activation); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(""); + token.WriteTo(writer); + } + else if (value is not Activation) + { + throw new TypeError($"Unable to use `CustomizedActivationJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var token = JToken.FromObject((value as Activation)!.GetType().Name); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + throw new NotImplementedException(); + //var dims = serializer.Deserialize(reader, typeof(string)); + //if (dims is null) + //{ + // throw new ValueError("Cannot deserialize 'null' to `Activation`."); + //} + //return new Shape((long[])(dims!)); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs new file mode 100644 index 00000000..4e190605 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs @@ -0,0 +1,48 @@ +using Newtonsoft.Json.Linq; +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedAxisJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Axis); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(new int[] { }); + token.WriteTo(writer); + } + else if (value is not Axis) + { + throw new TypeError($"Unable to use `CustomizedAxisJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var token = JToken.FromObject((value as Axis)!.axis); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + var axis = serializer.Deserialize(reader, typeof(long[])); + if (axis is null) + { + throw new ValueError("Cannot deserialize 'null' to `Axis`."); + } + return new Axis((int[])(axis!)); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs new file mode 100644 index 00000000..1ad19fc8 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs @@ -0,0 +1,73 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Tensorflow.Keras.Saving; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedNodeConfigJsonConverter : JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(NodeConfig); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if (value is null) + { + var token = JToken.FromObject(null); + token.WriteTo(writer); + } + else if (value is not NodeConfig) + { + throw new TypeError($"Unable to use `CustomizedNodeConfigJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var config = value as NodeConfig; + var token = JToken.FromObject(new object[] { config!.Name, config.NodeIndex, config.TensorIndex }); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + var values = serializer.Deserialize(reader, typeof(object[])) as object[]; + if (values is null) + { + throw new ValueError("Cannot deserialize 'null' to `Shape`."); + } + if(values.Length != 3) + { + throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); + } + if (values[0] is not string) + { + throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); + } + if (values[1] is not int) + { + throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); + } + if (values[2] is not int) + { + throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); + } + return new NodeConfig() + { + Name = values[0] as string, + NodeIndex = (int)values[1], + TensorIndex = (int)values[2] + }; + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs new file mode 100644 index 00000000..300cb2f2 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs @@ -0,0 +1,67 @@ +using Newtonsoft.Json; +using Newtonsoft.Json.Converters; +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Common +{ + public class CustomizedShapeJsonConverter: JsonConverter + { + public override bool CanConvert(Type objectType) + { + return objectType == typeof(Shape); + } + + public override bool CanRead => true; + + public override bool CanWrite => true; + + public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) + { + if(value is null) + { + var token = JToken.FromObject(null); + token.WriteTo(writer); + } + else if(value is not Shape) + { + throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}."); + } + else + { + var shape = (value as Shape)!; + long?[] dims = new long?[shape.ndim]; + for(int i = 0; i < dims.Length; i++) + { + if (shape.dims[i] == -1) + { + dims[i] = null; + } + else + { + dims[i] = shape.dims[i]; + } + } + var token = JToken.FromObject(dims); + token.WriteTo(writer); + } + } + + public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) + { + var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; + if(dims is null) + { + throw new ValueError("Cannot deserialize 'null' to `Shape`."); + } + long[] convertedDims = new long[dims.Length]; + for(int i = 0; i < dims.Length; i++) + { + convertedDims[i] = dims[i] ?? (-1); + } + return new Shape(convertedDims); + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs index 7280594b..6743935c 100644 --- a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs +++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs @@ -16,23 +16,27 @@ using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Engine { /// /// Specifies the ndim, dtype and shape of every input to a layer. /// - public class InputSpec + public class InputSpec: IKerasConfigable { public int? ndim; + public int? max_ndim; public int? min_ndim; Dictionary axes; Shape shape; + TF_DataType dtype; public int[] AllAxisDim; public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, int? ndim = null, int? min_ndim = null, + int? max_ndim = null, Dictionary axes = null, Shape shape = null) { @@ -41,7 +45,9 @@ namespace Tensorflow.Keras.Engine axes = new Dictionary(); this.axes = axes; this.min_ndim = min_ndim; + this.max_ndim = max_ndim; this.shape = shape; + this.dtype = dtype; if (ndim == null && shape != null) this.ndim = shape.ndim; @@ -49,7 +55,30 @@ namespace Tensorflow.Keras.Engine AllAxisDim = axes.Select(x => x.Value).ToArray(); } + public IKerasConfig get_config() + { + return new Config() + { + DType = dtype == TF_DataType.DtInvalid ? null : dtype, + Shape = shape, + Ndim = ndim, + MinNdim = min_ndim, + MaxNdim = max_ndim, + Axes = axes.ToDictionary(x => x.Key.ToString(), x => x.Value) + }; + } + public override string ToString() => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; + + public class Config: IKerasConfig + { + public TF_DataType? DType { get; set; } + public Shape Shape { get; set; } + public int? Ndim { get; set; } + public int? MinNdim { get;set; } + public int? MaxNdim { get;set; } + public IDictionary Axes { get; set; } + } } } diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index f1ca5632..ebf3358d 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -1,11 +1,12 @@ using System.Collections.Generic; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using Tensorflow.Training; namespace Tensorflow.Keras { - public interface ILayer: ITrackable + public interface ILayer: IWithTrackable, IKerasConfigable { string Name { get; } bool Trainable { get; } @@ -19,8 +20,8 @@ namespace Tensorflow.Keras List NonTrainableWeights { get; } Shape OutputShape { get; } Shape BatchInputShape { get; } + TensorShapeConfig BuildInputShape { get; } TF_DataType DType { get; } int count_params(); - LayerArgs get_config(); } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs new file mode 100644 index 00000000..1217e1e5 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs @@ -0,0 +1,15 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Keras.Saving +{ + public interface IKerasConfig + { + } + + public interface IKerasConfigable + { + IKerasConfig get_config(); + } +} diff --git a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs index b8b8cab4..4ce290c8 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs @@ -1,4 +1,5 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; using Tensorflow.Keras.ArgsDefinition; @@ -6,11 +7,15 @@ using Tensorflow.Keras.Engine; namespace Tensorflow.Keras.Saving { - public class LayerConfig + public class LayerConfig: IKerasConfig { + [JsonProperty("name")] public string Name { get; set; } + [JsonProperty("class_name")] public string ClassName { get; set; } + [JsonProperty("config")] public LayerArgs Config { get; set; } + [JsonProperty("inbound_nodes")] public List InboundNodes { get; set; } } } diff --git a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs index abfb235b..cac19180 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs @@ -1,15 +1,20 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; using Tensorflow.Keras.Engine; namespace Tensorflow.Keras.Saving { - public class ModelConfig + public class ModelConfig : IKerasConfig { + [JsonProperty("name")] public string Name { get; set; } + [JsonProperty("layers")] public List Layers { get; set; } + [JsonProperty("input_layers")] public List InputLayers { get; set; } + [JsonProperty("output_layers")] public List OutputLayers { get; set; } public override string ToString() diff --git a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs index 3132248e..20e2fef5 100644 --- a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs +++ b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs @@ -1,10 +1,13 @@ -using System; +using Newtonsoft.Json; +using System; using System.Collections.Generic; using System.Text; +using Tensorflow.Keras.Common; namespace Tensorflow.Keras.Saving { - public class NodeConfig + [JsonConverter(typeof(CustomizedNodeConfigJsonConverter))] + public class NodeConfig : IKerasConfig { public string Name { get; set; } public int NodeIndex { get; set; } diff --git a/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs new file mode 100644 index 00000000..7abcfde2 --- /dev/null +++ b/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs @@ -0,0 +1,21 @@ +using Newtonsoft.Json; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Tensorflow.Keras.Saving +{ + public class TensorShapeConfig + { + [JsonProperty("class_name")] + public string ClassName { get; set; } = "TensorShape"; + [JsonProperty("items")] + public long?[] Items { get; set; } + + public static implicit operator Shape(TensorShapeConfig shape) + => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); + + public static implicit operator TensorShapeConfig(Shape shape) + => new TensorShapeConfig() { Items = shape.dims.Select(x => x == -1 ? null : x).ToArray() }; + } +} diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs index 6c7189df..709ca9b2 100644 --- a/src/TensorFlowNET.Core/NumPy/Axis.cs +++ b/src/TensorFlowNET.Core/NumPy/Axis.cs @@ -14,20 +14,29 @@ limitations under the License. ******************************************************************************/ +using Newtonsoft.Json; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Keras.Common; namespace Tensorflow { - public record Axis(params int[] axis) + [JsonConverter(typeof(CustomizedAxisJsonConverter))] + public class Axis { + public int[] axis { get; set; } public int size => axis == null ? -1 : axis.Length; public bool IsScalar { get; init; } public int this[int index] => axis[index]; + public Axis(params int[] axis) + { + this.axis = axis; + } + public static implicit operator int[]?(Axis axis) => axis?.axis; diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs index bc79fefc..ecf73586 100644 --- a/src/TensorFlowNET.Core/Numpy/Shape.cs +++ b/src/TensorFlowNET.Core/Numpy/Shape.cs @@ -14,14 +14,17 @@ limitations under the License. ******************************************************************************/ +using Newtonsoft.Json; using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow.Keras.Common; using Tensorflow.NumPy; namespace Tensorflow { + [JsonConverter(typeof(CustomizedShapeJsonConverter))] public class Shape { public int ndim => _dims == null ? -1 : _dims.Length; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs index fdcb5aff..e7e9955c 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class Constant : IInitializer @@ -22,11 +24,19 @@ namespace Tensorflow.Operations.Initializers T value; bool _verify_shape; + private readonly Dictionary _config; + + public string ClassName => "Constant"; + public IDictionary Config => _config; + public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) { this.value = value; this.dtype = dtype; _verify_shape = verify_shape; + + _config = new Dictionary(); + _config["value"] = this.value; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs index d97d8830..def1cb7a 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs @@ -14,10 +14,17 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class GlorotUniform : VarianceScaling { + private readonly Dictionary _config; + + public override string ClassName => "GlorotUniform"; + public override IDictionary Config => _config; + public GlorotUniform(float scale = 1.0f, string mode = "FAN_AVG", bool uniform = true, @@ -28,7 +35,8 @@ namespace Tensorflow.Operations.Initializers seed: seed, dtype: dtype) { - + _config = new Dictionary(); + _config["seed"] = _seed; } } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs index 50d4d503..9748b100 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs @@ -14,10 +14,17 @@ limitations under the License. ******************************************************************************/ +using Newtonsoft.Json; +using System.Collections.Generic; + namespace Tensorflow { public interface IInitializer { + [JsonProperty("class_name")] + string ClassName { get; } + [JsonProperty("config")] + IDictionary Config { get; } Tensor Apply(InitializerArgs args); } } diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs index 02d3c93b..3077a1e0 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs @@ -14,12 +14,19 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class Ones : IInitializer { private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "Ones"; + public IDictionary Config => new Dictionary(); + public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) { this.dtype = dtype; diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs index 254a7ee7..cdc1c3ed 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs @@ -1,9 +1,14 @@ using System; +using System.Collections.Generic; namespace Tensorflow.Operations.Initializers { public class Orthogonal : IInitializer { + private readonly Dictionary _config; + + public string ClassName => "Orthogonal"; + public IDictionary Config => throw new NotImplementedException(); public Tensor Apply(InitializerArgs args) { throw new NotImplementedException(); diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs index 029b311b..21fa7e2b 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class RandomNormal : IInitializer @@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers private int? seed; private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "RandomNormal"; + public IDictionary Config => _config; + public RandomNormal(float mean = 0.0f, float stddev = 0.05f, int? seed = null, @@ -32,6 +39,11 @@ namespace Tensorflow.Operations.Initializers this.stddev = stddev; this.seed = seed; this.dtype = dtype; + + _config = new Dictionary(); + _config["mean"] = this.mean; + _config["stddev"] = this.stddev; + _config["seed"] = this.seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs index a49d5921..87404708 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class RandomUniform : IInitializer @@ -23,12 +25,22 @@ namespace Tensorflow.Operations.Initializers private float maxval; private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "RandomUniform"; + public IDictionary Config => _config; + public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null) { this.dtype = dtype; this.minval = minval; this.maxval = maxval; this.seed = seed; + + _config = new Dictionary(); + _config["minval"] = this.minval; + _config["maxval"] = this.maxval; + _config["seed"] = this.seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs index 048c11e7..c1c3e999 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class TruncatedNormal : IInitializer @@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers private int? seed; private TF_DataType dtype; + private readonly Dictionary _config; + + public string ClassName => "TruncatedNormal"; + public IDictionary Config => _config; + public TruncatedNormal(float mean = 0.0f, float stddev = 1.0f, int? seed = null, @@ -32,6 +39,10 @@ namespace Tensorflow.Operations.Initializers this.stddev = stddev; this.seed = seed; this.dtype = dtype; + _config = new Dictionary(); + _config["mean"] = this.mean; + _config["stddev"] = this.stddev; + _config["seed"] = this.seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs index d313f4c9..f104e8e8 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs @@ -15,7 +15,9 @@ ******************************************************************************/ using System; +using System.Collections.Generic; using System.Linq; +using System.Linq.Expressions; namespace Tensorflow.Operations.Initializers { @@ -30,6 +32,11 @@ namespace Tensorflow.Operations.Initializers protected int? _seed; protected TF_DataType _dtype; protected bool _uniform; + private readonly Dictionary _config; + + public virtual string ClassName => "VarianceScaling"; + + public virtual IDictionary Config => _config; public VarianceScaling(float factor = 2.0f, string mode = "FAN_IN", @@ -50,6 +57,12 @@ namespace Tensorflow.Operations.Initializers _seed = seed; _dtype = dtype; _uniform = uniform; + + _config = new(); + _config["scale"] = _scale; + _config["mode"] = _mode; + _config["distribution"] = _distribution; + _config["seed"] = _seed; } public Tensor Apply(InitializerArgs args) diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs index 5d045292..c4ed25a1 100644 --- a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs +++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs @@ -14,6 +14,8 @@ limitations under the License. ******************************************************************************/ +using System.Collections.Generic; + namespace Tensorflow.Operations.Initializers { public class Zeros : IInitializer @@ -21,6 +23,9 @@ namespace Tensorflow.Operations.Initializers Shape shape; TF_DataType dtype; + public string ClassName => "Zeros"; + public IDictionary Config => new Dictionary(); + public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) { this.shape = shape; diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 734f2608..c29ed47b 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -20,6 +20,7 @@ using Tensorflow.Keras; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; using Tensorflow.Operations; using Tensorflow.Train; using Tensorflow.Util; @@ -76,6 +77,8 @@ namespace Tensorflow public Shape BatchInputShape => throw new NotImplementedException(); + public TensorShapeConfig BuildInputShape => throw new NotImplementedException(); + public TF_DataType DType => throw new NotImplementedException(); protected bool built = false; public bool Built => built; @@ -144,7 +147,7 @@ namespace Tensorflow throw new NotImplementedException(); } - public LayerArgs get_config() + public IKerasConfig get_config() { throw new NotImplementedException(); } diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj index 0ebe61d0..7068ed47 100644 --- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj +++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj @@ -108,6 +108,7 @@ https://tensorflownet.readthedocs.io + diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs index 372ac676..deeb9e4b 100644 --- a/src/TensorFlowNET.Core/Tensors/dtypes.cs +++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs @@ -202,6 +202,24 @@ namespace Tensorflow _ => type.ToString() }; + public static string as_python_name(this TF_DataType type) + => type switch + { + TF_DataType.TF_STRING => "str", + TF_DataType.TF_UINT8 => "uint8", + TF_DataType.TF_INT8 => "int8", + TF_DataType.TF_UINT32 => "uint32", + TF_DataType.TF_INT32 => "int32", + TF_DataType.TF_UINT64 => "uint64", + TF_DataType.TF_INT64 => "int64", + TF_DataType.TF_FLOAT => "float32", + TF_DataType.TF_DOUBLE => "float64", + TF_DataType.TF_BOOL => "bool", + TF_DataType.TF_RESOURCE => "resource", + TF_DataType.TF_VARIANT => "variant", + _ => type.ToString() + }; + public static int get_datatype_size(this TF_DataType type) => type.as_base_dtype() switch { diff --git a/src/TensorFlowNET.Core/Training/ITrackable.cs b/src/TensorFlowNET.Core/Training/IWithTrackable.cs similarity index 82% rename from src/TensorFlowNET.Core/Training/ITrackable.cs rename to src/TensorFlowNET.Core/Training/IWithTrackable.cs index e4ef2c8f..87eda879 100644 --- a/src/TensorFlowNET.Core/Training/ITrackable.cs +++ b/src/TensorFlowNET.Core/Training/IWithTrackable.cs @@ -5,7 +5,7 @@ using Tensorflow.Train; namespace Tensorflow.Training { - public interface ITrackable + public interface IWithTrackable { Trackable GetTrackable(); } diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs index 434d51b6..132571f2 100644 --- a/src/TensorFlowNET.Core/Training/Trackable.cs +++ b/src/TensorFlowNET.Core/Training/Trackable.cs @@ -26,7 +26,7 @@ using static Tensorflow.Binding; namespace Tensorflow.Train { - public abstract class Trackable: ITrackable + public abstract class Trackable: IWithTrackable { /// /// Corresponding to tensorflow/python/trackable/constants.py diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs index a221444b..3aeb3200 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine { public partial class Functional { - public ModelConfig get_config() + public override IKerasConfig get_config() { return get_network_config(); } @@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine { Name = name }; - + var node_conversion_map = new Dictionary(); foreach (var layer in _self_tracked_trackables) { @@ -42,23 +42,26 @@ namespace Tensorflow.Keras.Engine } var layer_configs = new List(); - foreach (var layer in _self_tracked_trackables) + using (SharedObjectSavingScope.Enter()) { - var filtered_inbound_nodes = new List(); - foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) + foreach (var layer in _self_tracked_trackables) { - var node_key = _make_node_key(layer.Name, original_node_index); - if (NetworkNodes.Contains(node_key) && !node.is_input) + var filtered_inbound_nodes = new List(); + foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) { - var node_data = node.serialize(_make_node_key, node_conversion_map); - filtered_inbound_nodes.append(node_data); + var node_key = _make_node_key(layer.Name, original_node_index); + if (NetworkNodes.Contains(node_key) && !node.is_input) + { + var node_data = node.serialize(_make_node_key, node_conversion_map); + filtered_inbound_nodes.append(node_data); + } } - } - var layer_config = generic_utils.serialize_layer_to_config(layer); - layer_config.Name = layer.Name; - layer_config.InboundNodes = filtered_inbound_nodes; - layer_configs.Add(layer_config); + var layer_config = generic_utils.serialize_layer_to_config(layer); + layer_config.Name = layer.Name; + layer_config.InboundNodes = filtered_inbound_nodes; + layer_configs.Add(layer_config); + } } config.Layers = layer_configs; diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs index 7c8812ad..44eaef53 100644 --- a/src/TensorFlowNET.Keras/Engine/Functional.cs +++ b/src/TensorFlowNET.Keras/Engine/Functional.cs @@ -70,6 +70,7 @@ namespace Tensorflow.Keras.Engine this.inputs = inputs; this.outputs = outputs; built = true; + _buildInputShape = inputs.shape; if (outputs.Any(x => x.KerasHistory == null)) base_layer_utils.create_keras_history(outputs); @@ -357,5 +358,22 @@ namespace Tensorflow.Keras.Engine return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) .ToDictionary(x => x.Key, x => x.Value); } + + protected override void _init_set_name(string name, bool zero_based = true) + { + if (string.IsNullOrEmpty(name)) + { + string class_name = GetType().Name; + if (this.GetType() == typeof(Functional)) + { + class_name = "Model"; + } + this.name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(class_name), zero_based: zero_based); + } + else + { + this.name = name; + } + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index a2f92ba8..31b37d68 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -61,6 +61,7 @@ namespace Tensorflow.Keras.Engine /// Provides information about which inputs are compatible with the layer. /// protected InputSpec inputSpec; + public InputSpec InputSpec => inputSpec; bool dynamic = true; public bool SupportsMasking { get; set; } protected List _trainable_weights; @@ -79,6 +80,8 @@ namespace Tensorflow.Keras.Engine protected bool computePreviousMask; protected List updates; public Shape BatchInputShape => args.BatchInputShape; + protected TensorShapeConfig _buildInputShape = null; + public TensorShapeConfig BuildInputShape => _buildInputShape; List inboundNodes; public List InboundNodes => inboundNodes; @@ -223,6 +226,7 @@ namespace Tensorflow.Keras.Engine public virtual void build(Shape input_shape) { + _buildInputShape = input_shape; built = true; } @@ -310,7 +314,7 @@ namespace Tensorflow.Keras.Engine public List Variables => weights; - public virtual LayerArgs get_config() + public virtual IKerasConfig get_config() => args; } } diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index 59b205e4..85da920e 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using Tensorflow.Functions; using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Saving; using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.ModelSaving; @@ -30,7 +31,10 @@ namespace Tensorflow.Keras.Engine } else { - KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + using (SharedObjectSavingScope.Enter()) + { + KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); + } } } } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs index 6e790a26..45f64720 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs @@ -25,6 +25,7 @@ namespace Tensorflow.Keras.Layers { { throw new ValueError("Alpha must be a number greater than 0."); } + _buildInputShape = input_shape; built = true; } diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs index aba175de..2fd2caee 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs @@ -14,6 +14,7 @@ namespace Tensorflow.Keras.Layers { } public override void build(Shape input_shape) { + _buildInputShape = input_shape; built = true; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs index b12d7dee..1ef8d0e5 100644 --- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs +++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs @@ -16,10 +16,11 @@ namespace Tensorflow.Keras.Layers { // SELU has no arguments } public override void build(Shape input_shape) { - if ( alpha < 0f ) { - throw new ValueError("Alpha must be a number greater than 0."); - } - built = true; + if ( alpha < 0f ) { + throw new ValueError("Alpha must be a number greater than 0."); + } + _buildInputShape = input_shape; + built = true; } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { Tensor output = inputs; diff --git a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs index 6f6dd7e8..c5131630 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs @@ -4,6 +4,7 @@ using System.Collections; using System.Collections.Generic; using System.Linq; using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers { @@ -146,7 +147,7 @@ namespace Tensorflow.Keras.Layers return scores; } - public override LayerArgs get_config() => this.args; + public override IKerasConfig get_config() => this.args; //var config = new Dictionary { // { // "use_scale", diff --git a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs index 3f618b5d..1348e19c 100644 --- a/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs +++ b/src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs @@ -5,6 +5,7 @@ using static Tensorflow.KerasApi; using System; using System.Collections.Generic; using System.Linq; +using Tensorflow.Keras.Saving; /// /// Base class for attention layers that can be used in sequence DNN/CNN models. @@ -252,6 +253,6 @@ namespace Tensorflow.Keras.Layers return tf.logical_and(x, y); } - public override LayerArgs get_config() => this.args; + public override IKerasConfig get_config() => this.args; } } diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs index e0a337ca..b8286be6 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs @@ -49,6 +49,7 @@ namespace Tensorflow.Keras.Layers initializer: bias_initializer, trainable: true); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs index 912a429b..933aa9cf 100644 --- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs +++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs @@ -98,6 +98,7 @@ namespace Tensorflow.Keras.Layers name: tf_op_name); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false) diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs index e4c22745..ca8007d0 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs @@ -43,6 +43,7 @@ namespace Tensorflow.Keras.Layers public override void build(Shape input_shape) { + _buildInputShape = input_shape; var last_dim = input_shape.dims.Last(); var axes = new Dictionary(); axes[-1] = (int)last_dim; diff --git a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs index 79f4e5ce..606f387b 100644 --- a/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs +++ b/src/TensorFlowNET.Keras/Layers/Core/Embedding.cs @@ -62,6 +62,7 @@ namespace Tensorflow.Keras.Layers name: "embeddings"); tf.Context.graph_mode(); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs index 45f5bf0f..44b338c2 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs +++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs @@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Layers { throw new ValueError("The `cropping` argument must be a tuple of 2 integers."); } built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs index 6cb03e1e..1f33ee3a 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs +++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs @@ -13,7 +13,8 @@ namespace Tensorflow.Keras.Layers { this.args = args; } public override void build(Shape input_shape) { - built = true; + built = true; + _buildInputShape = input_shape; } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { Tensor output = inputs; diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs index 2d6751bf..838a5043 100644 --- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs +++ b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs @@ -12,7 +12,8 @@ namespace Tensorflow.Keras.Layers { } public override void build(Shape input_shape) { - built = true; + built = true; + _buildInputShape = input_shape; } protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs index 50c66be7..c1ec0ddc 100644 --- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs +++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs @@ -300,7 +300,8 @@ namespace Tensorflow.Keras.Layers => new Dense(new DenseArgs { Units = units, - Activation = GetActivationByName("linear") + Activation = GetActivationByName("linear"), + ActivationName = "linear" }); /// @@ -321,6 +322,7 @@ namespace Tensorflow.Keras.Layers { Units = units, Activation = GetActivationByName(activation), + ActivationName = activation, InputShape = input_shape }); diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs index 5f821760..da7e857a 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs @@ -37,6 +37,7 @@ namespace Tensorflow.Keras.Layers }).ToArray(); shape_set.Add(shape); }*/ + _buildInputShape = input_shape; } protected override Tensors _merge_function(Tensors inputs) diff --git a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs index 0363d58f..3cd43af9 100644 --- a/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs +++ b/src/TensorFlowNET.Keras/Layers/Merging/Merge.cs @@ -17,6 +17,7 @@ namespace Tensorflow.Keras.Layers public override void build(Shape input_shape) { // output_shape = input_shape.dims[1^]; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs index dac92f81..c0b16c81 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs @@ -118,6 +118,7 @@ namespace Tensorflow.Keras.Layers throw new NotImplementedException("build when renorm is true"); built = true; + _buildInputShape = input_shape; } public override Shape ComputeOutputShape(Shape input_shape) diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs index 5eebd735..e19b9c30 100644 --- a/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs +++ b/src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs @@ -81,6 +81,7 @@ namespace Tensorflow.Keras.Layers _fused = _fused_can_be_used(ndims); built = true; + _buildInputShape = input_shape; } bool _fused_can_be_used(int ndims) diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs index 868506b6..8e7a19a9 100644 --- a/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs +++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs @@ -24,6 +24,7 @@ namespace Tensorflow.Keras.Layers { permute = new int[input_shape.rank]; dims.CopyTo(permute, 1); built = true; + _buildInputShape = input_shape; } protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) { diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs index c8366ff4..38abe2a7 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs @@ -18,6 +18,7 @@ namespace Tensorflow.Keras.Layers.Rnn public override void build(Shape input_shape) { var input_dim = input_shape[-1]; + _buildInputShape = input_shape; kernel = add_weight("kernel", (input_shape[-1], args.Units), initializer: args.KernelInitializer diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs index eead274a..20962df1 100644 --- a/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs +++ b/src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs @@ -4,6 +4,7 @@ using System.ComponentModel; using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Layers.Rnn { @@ -136,7 +137,7 @@ namespace Tensorflow.Keras.Layers.Rnn // self.built = True } - public override LayerArgs get_config() + public override IKerasConfig get_config() { throw new NotImplementedException(); //def get_config(self): diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs index 4ff8f02f..9d1c9609 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs @@ -79,7 +79,7 @@ public partial class KerasSavedModelUtils var path = node_paths[node]; string node_path; - if (path is null) + if (path is null || path.Count() == 0) { node_path = "root"; } diff --git a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs index 655127af..8675ea65 100644 --- a/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs +++ b/src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs @@ -1,6 +1,7 @@ using System.Collections.Generic; using Newtonsoft.Json; using Newtonsoft.Json.Linq; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Engine; using Tensorflow.Keras.Layers; using Tensorflow.Keras.Utils; @@ -85,31 +86,38 @@ public class LayerSavedModelSaver: SavedModelSaver JObject metadata = new JObject(); metadata["name"] = _layer.Name; metadata["trainable"] = _layer.Trainable; - // metadata["expects_training_arg"] = _obj._expects_training_arg; - // metadata["dtype"] = policy.serialize(_obj._dtype_policy) + // TODO: implement `expects_training_arg`. + metadata["expects_training_arg"] = false; + metadata["dtype"] = _layer.DType.as_python_name(); metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape); // metadata["stateful"] = _obj.stateful; // metadata["must_restore_from_config"] = _obj.must_restore_from_config; // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; metadata["autocast"] = _layer.AutoCast; - var temp = JObject.FromObject(get_serialized(_layer)); - metadata.Merge(temp, new JsonMergeSettings + if(_layer.InputSpec is not null) + { + metadata["input_spec"] = generic_utils.serialize_keras_object(_layer.InputSpec); + } + + metadata.Merge(get_serialized(_layer), new JsonMergeSettings { // Handle conflicts by using values from obj2 MergeArrayHandling = MergeArrayHandling.Merge }); // skip the check of `input_spec` and `build_input_shape` for the lack of members. // skip the check of `activity_regularizer` for the type problem. + if(_layer.BuildInputShape is not null) + { + metadata["build_input_shape"] = JToken.FromObject(_layer.BuildInputShape); + } return metadata.ToString(); } } - public static IDictionary get_serialized(Layer obj) + public static JObject get_serialized(Layer obj) { - // TODO: complete the implmentation (need to revise `get_config`). - return new Dictionary(); - //return generic_utils.serialize_keras_object(obj); + return generic_utils.serialize_keras_object(obj); } } @@ -135,18 +143,19 @@ public class InputLayerSavedModelSaver: SavedModelSaver { get { - if(_obj is not Layer) + if(_obj is not InputLayer) { throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer."); } - var layer = (Layer)_obj; + var layer = (InputLayer)_obj; + var config = (layer.get_config() as InputLayerArgs)!; var info = new { class_name = layer.GetType().Name, name = layer.Name, dtype = layer.DType, - //sparse = layer.sparse, - //ragged = layer.ragged, + sparse = config.Sparse, + ragged = config.Ragged, batch_input_shape = layer.BatchInputShape, config = layer.get_config() }; diff --git a/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs deleted file mode 100644 index 4c2ecc0d..00000000 --- a/src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs +++ /dev/null @@ -1,15 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Linq; - -namespace Tensorflow.Keras.Saving -{ - public class TensorShapeConfig - { - public string ClassName { get; set; } - public int?[] Items { get; set; } - - public static implicit operator Shape(TensorShapeConfig shape) - => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); - } -} diff --git a/src/TensorFlowNET.Keras/Saving/serialization.cs b/src/TensorFlowNET.Keras/Saving/serialization.cs new file mode 100644 index 00000000..d5e46d11 --- /dev/null +++ b/src/TensorFlowNET.Keras/Saving/serialization.cs @@ -0,0 +1,125 @@ +using Newtonsoft.Json.Linq; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Text; +using Tensorflow.Keras.Saving.SavedModel; + +namespace Tensorflow.Keras.Saving +{ + // TODO: make it thread safe. + public class SharedObjectSavingScope: IDisposable + { + private class WeakReferenceEqualityComparer: IEqualityComparer> + { + public bool Equals(WeakReference x, WeakReference y) + { + if(!x.TryGetTarget(out var tx)) + { + return false; + } + if(!y.TryGetTarget(out var ty)) + { + return false; + } + return tx.Equals(ty); + } + public int GetHashCode(WeakReference obj) + { + if (!obj.TryGetTarget(out var w)) + { + return 0; + } + return w.GetHashCode(); + } + } + private static SharedObjectSavingScope? _instance = null; + private readonly Dictionary, int> _shared_object_ids= new Dictionary, int>(); + private int _currentId = 0; + /// + /// record how many times the scope is nested. + /// + private int _nestedDepth = 0; + private SharedObjectSavingScope() + { + + } + + public static SharedObjectSavingScope Enter() + { + if(_instance is not null) + { + _instance._nestedDepth++; + return _instance; + } + else + { + _instance = new SharedObjectSavingScope(); + _instance._nestedDepth++; + return _instance; + } + } + + public static SharedObjectSavingScope GetScope() + { + return _instance; + } + + public int GetId(object? obj) + { + if(obj is null) + { + return _currentId++; + } + var maybe_key = _shared_object_ids.Keys.SingleOrDefault(x => new WeakReferenceEqualityComparer().Equals(x, new WeakReference(obj))); + if (maybe_key is not null) + { + return _shared_object_ids[maybe_key]; + } + _shared_object_ids[new WeakReference(obj)] = _currentId++; + return _currentId; + } + + public void Dispose() + { + _nestedDepth--; + if(_nestedDepth== 0) + { + _instance = null; + } + } + } + + public static class serialize_utils + { + public static readonly string SHARED_OBJECT_KEY = "shared_object_id"; + /// + /// Returns the serialization of the class with the given config. + /// + /// + /// + /// + /// + /// + public static JObject serialize_keras_class_and_config(string class_name, JToken config, object? obj = null, int? shared_object_id = null) + { + JObject res = new JObject(); + res["class_name"] = class_name; + res["config"] = config; + + if(shared_object_id is not null) + { + res[SHARED_OBJECT_KEY] = shared_object_id!; + } + + var scope = SharedObjectSavingScope.GetScope(); + if(scope is not null && obj is not null) + { + res[SHARED_OBJECT_KEY] = scope.GetId(obj); + } + + return res; + } + } +} diff --git a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs index 1e6ce409..d845f3ca 100644 --- a/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/base_layer_utils.cs @@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Utils } /// - /// Makes a layer name (or arbitrary string) unique within a TensorFlow graph. + /// Makes a layer name (or arbitrary string) unique within a TensorFlow graph. (correponding to `backend.unique_object_name` of python.) /// /// /// diff --git a/src/TensorFlowNET.Keras/Utils/generic_utils.cs b/src/TensorFlowNET.Keras/Utils/generic_utils.cs index 68903eb2..730a33e3 100644 --- a/src/TensorFlowNET.Keras/Utils/generic_utils.cs +++ b/src/TensorFlowNET.Keras/Utils/generic_utils.cs @@ -14,10 +14,14 @@ limitations under the License. ******************************************************************************/ +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; using System; using System.Collections; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; +using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.Saving; namespace Tensorflow.Keras.Utils @@ -32,13 +36,21 @@ namespace Tensorflow.Keras.Utils public static LayerConfig serialize_layer_to_config(ILayer instance) { var config = instance.get_config(); + Debug.Assert(config is LayerArgs); return new LayerConfig { - Config = config, + Config = config as LayerArgs, ClassName = instance.GetType().Name }; } + public static JObject serialize_keras_object(IKerasConfigable instance) + { + var config = JToken.FromObject(instance.get_config()); + // TODO: change the class_name to registered name, instead of system class name. + return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); + } + public static string to_snake_case(string name) { return string.Concat(name.Select((x, i) => diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs index 0a1098af..67e8ff79 100644 --- a/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs @@ -1,6 +1,8 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using Tensorflow.Keras.Engine; +using System.Diagnostics; using static Tensorflow.KerasApi; +using Tensorflow.Keras.Saving; namespace TensorFlowNET.Keras.UnitTest { @@ -15,7 +17,8 @@ namespace TensorFlowNET.Keras.UnitTest { var model = GetFunctionalModel(); var config = model.get_config(); - var new_model = keras.models.from_config(config); + Debug.Assert(config is ModelConfig); + var new_model = keras.models.from_config(config as ModelConfig); Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); } diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs index 0f34ff10..90d0a48a 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs @@ -15,17 +15,14 @@ using Tensorflow.Keras.Layers; using Tensorflow.Keras.Losses; using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Optimizers; +using Tensorflow.Operations; namespace TensorFlowNET.Keras.UnitTest; -// class MNISTLoader -// { -// public MNISTLoader() -// { -// var mnist = new MnistModelLoader() -// -// } -// } +public static class AutoGraphExtension +{ + +} [TestClass] public class SaveTest @@ -42,6 +39,8 @@ public class SaveTest model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"}); + var g = ops.get_default_graph(); + var data_loader = new MnistModelLoader(); var num_epochs = 1; var batch_size = 50; @@ -50,11 +49,34 @@ public class SaveTest { TrainDir = "mnist", OneHot = false, - ValidationSize = 0, + ValidationSize = 50000, }).Result; model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); } + + [TestMethod] + public void Temp() + { + var graph = new Graph(); + var g = graph.as_default(); + //var input_tensor = array_ops.placeholder(TF_DataType.TF_FLOAT, new int[] { 1 }, "test_string_tensor"); + var input_tensor = tf.placeholder(tf.int32, new int[] { 1 }, "aa"); + var wrapped_func = tf.autograph.to_graph(func); + var res = wrapped_func(input_tensor); + g.Exit(); + } + + private Tensor func(Tensor tensor) + { + return gen_ops.neg(tensor); + //return array_ops.identity(tensor); + //tf.device("cpu:0"); + //using (ops.control_dependencies(new object[] { res.op })) + //{ + // return array_ops.identity(tensor); + //} + } } \ No newline at end of file