diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs
index ca35d75d..a37973bc 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs
@@ -1,9 +1,18 @@
-using System;
+using Newtonsoft.Json;
+using System;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Keras.ArgsDefinition {
- public class SoftmaxArgs : LayerArgs {
- public Axis axis { get; set; } = -1;
- }
+ public class SoftmaxArgs : LayerArgs
+ {
+ [JsonProperty("axis")]
+ public Axis axis { get; set; } = -1;
+ [JsonProperty("name")]
+ public override string Name { get => base.Name; set => base.Name = value; }
+ [JsonProperty("trainable")]
+ public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
+ [JsonProperty("dtype")]
+ public override TF_DataType DType { get => base.DType; set => base.DType = value; }
+ }
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs
new file mode 100644
index 00000000..66b34a1a
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs
@@ -0,0 +1,19 @@
+using Newtonsoft.Json;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+ public class AutoSerializeLayerArgs: LayerArgs
+ {
+ [JsonProperty("name")]
+ public override string Name { get => base.Name; set => base.Name = value; }
+ [JsonProperty("dtype")]
+ public override TF_DataType DType { get => base.DType; set => base.DType = value; }
+ [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
+ public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
+ [JsonProperty("trainable")]
+ public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs
index e9b3c2fd..8f4facbd 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs
@@ -1,13 +1,18 @@
-using System;
+using Newtonsoft.Json;
+using System;
+using System.Xml.Linq;
+using Tensorflow.Operations.Initializers;
using static Tensorflow.Binding;
namespace Tensorflow.Keras.ArgsDefinition
{
+ // TODO: `activity_regularizer`
public class DenseArgs : LayerArgs
{
///
/// Positive integer, dimensionality of the output space.
///
+ [JsonProperty("units")]
public int Units { get; set; }
///
@@ -15,39 +20,74 @@ namespace Tensorflow.Keras.ArgsDefinition
///
public Activation Activation { get; set; }
+ private string _activationName;
+ [JsonProperty("activation")]
+ public string ActivationName
+ {
+ get
+ {
+ if (string.IsNullOrEmpty(_activationName))
+ {
+ return Activation.Method.Name;
+ }
+ else
+ {
+ return _activationName;
+ }
+ }
+ set
+ {
+ _activationName = value;
+ }
+ }
+
///
/// Whether the layer uses a bias vector.
///
+ [JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;
///
/// Initializer for the `kernel` weights matrix.
///
+ [JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
///
/// Initializer for the bias vector.
///
+ [JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
///
/// Regularizer function applied to the `kernel` weights matrix.
///
+ [JsonProperty("kernel_regularizer")]
public IRegularizer KernelRegularizer { get; set; }
///
/// Regularizer function applied to the bias vector.
///
+ [JsonProperty("bias_regularizer")]
public IRegularizer BiasRegularizer { get; set; }
///
/// Constraint function applied to the `kernel` weights matrix.
///
+ [JsonProperty("kernel_constraint")]
public Action KernelConstraint { get; set; }
///
/// Constraint function applied to the bias vector.
///
+ [JsonProperty("bias_constraint")]
public Action BiasConstraint { get; set; }
+
+ [JsonProperty("name")]
+ public override string Name { get => base.Name; set => base.Name = value; }
+ [JsonProperty("dtype")]
+ public override TF_DataType DType { get => base.DType; set => base.DType = value; }
+ [JsonProperty("trainable")]
+ public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs
index 723109c2..be43e0a6 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs
@@ -1,9 +1,22 @@
-namespace Tensorflow.Keras.ArgsDefinition
+using Newtonsoft.Json;
+using Newtonsoft.Json.Serialization;
+using Tensorflow.Keras.Common;
+
+namespace Tensorflow.Keras.ArgsDefinition
{
public class InputLayerArgs : LayerArgs
{
+ [JsonIgnore]
public Tensor InputTensor { get; set; }
- public bool Sparse { get; set; }
+ [JsonProperty("sparse")]
+ public virtual bool Sparse { get; set; }
+ [JsonProperty("ragged")]
public bool Ragged { get; set; }
+ [JsonProperty("name")]
+ public override string Name { get => base.Name; set => base.Name = value; }
+ [JsonProperty("dtype")]
+ public override TF_DataType DType { get => base.DType; set => base.DType = value; }
+ [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
+ public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
index f3cca438..8ce1ec65 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
@@ -1,8 +1,9 @@
using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Saving;
namespace Tensorflow.Keras.ArgsDefinition
{
- public class DataAdapterArgs
+ public class DataAdapterArgs: IKerasConfig
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
index b6e6849b..fd603a85 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
@@ -1,8 +1,9 @@
using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Saving;
namespace Tensorflow.Keras.ArgsDefinition
{
- public class DataHandlerArgs
+ public class DataHandlerArgs: IKerasConfig
{
public Tensor X { get; set; }
public Tensor Y { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
index 4df4fb2b..febf1417 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
@@ -1,51 +1,54 @@
-namespace Tensorflow.Keras.ArgsDefinition
+using Newtonsoft.Json;
+using Tensorflow.Keras.Saving;
+
+namespace Tensorflow.Keras.ArgsDefinition
{
- public class LayerArgs
+ [JsonObject(MemberSerialization.OptIn)]
+ public class LayerArgs: IKerasConfig
{
///
/// Indicates whether the layer's weights are updated during training
/// and whether the layer's updates are run during training.
///
- public bool Trainable { get; set; } = true;
-
- public string Name { get; set; }
+ public virtual bool Trainable { get; set; } = true;
+ public virtual string Name { get; set; }
///
/// Only applicable to input layers.
///
- public TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT;
+ public virtual TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT;
///
/// Whether the `call` method can be used to build a TF graph without issues.
/// This attribute has no effect if the model is created using the Functional
/// API. Instead, `model.dynamic` is determined based on the internal layers.
///
- public bool Dynamic { get; set; } = false;
+ public virtual bool Dynamic { get; set; } = false;
///
/// Only applicable to input layers.
///
- public Shape InputShape { get; set; }
+ public virtual Shape InputShape { get; set; }
///
/// Only applicable to input layers.
///
- public Shape BatchInputShape { get; set; }
+ public virtual Shape BatchInputShape { get; set; }
- public int BatchSize { get; set; } = -1;
+ public virtual int BatchSize { get; set; } = -1;
///
/// Initial weight values.
///
- public float[] Weights { get; set; }
+ public virtual float[] Weights { get; set; }
///
/// Regularizer function applied to the output of the layer(its "activation").
///
- public IRegularizer ActivityRegularizer { get; set; }
+ public virtual IRegularizer ActivityRegularizer { get; set; }
- public bool Autocast { get; set; }
+ public virtual bool Autocast { get; set; }
- public bool IsFromConfig { get; set; }
+ public virtual bool IsFromConfig { get; set; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs
index 0d9e26ac..ad55ff61 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs
@@ -1,6 +1,8 @@
-namespace Tensorflow.Keras.ArgsDefinition
+using Tensorflow.Keras.Saving;
+
+namespace Tensorflow.Keras.ArgsDefinition
{
- public class NodeArgs
+ public class NodeArgs: IKerasConfig
{
public ILayer[] InboundLayers { get; set; }
public int[] NodeIndices { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs
index e2a0e43c..6256fd32 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs
@@ -1,6 +1,8 @@
-namespace Tensorflow.Keras.ArgsDefinition
+using Tensorflow.Keras.Saving;
+
+namespace Tensorflow.Keras.ArgsDefinition
{
- public class OptimizerV2Args
+ public class OptimizerV2Args: IKerasConfig
{
public string Name { get; set; }
public float LearningRate { get; set; } = 0.001f;
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs
index c2b48cc2..91ffc205 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs
@@ -1,7 +1,10 @@
-namespace Tensorflow.Keras.ArgsDefinition
+using Newtonsoft.Json;
+
+namespace Tensorflow.Keras.ArgsDefinition
{
- public class FlattenArgs : LayerArgs
+ public class FlattenArgs : AutoSerializeLayerArgs
{
+ [JsonProperty("data_format")]
public string DataFormat { get; set; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs
new file mode 100644
index 00000000..1bc13caf
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs
@@ -0,0 +1,50 @@
+using Newtonsoft.Json;
+using Newtonsoft.Json.Converters;
+using Newtonsoft.Json.Linq;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.Common
+{
+ public class CustomizedActivationJsonConverter : JsonConverter
+ {
+ public override bool CanConvert(Type objectType)
+ {
+ return objectType == typeof(Activation);
+ }
+
+ public override bool CanRead => true;
+
+ public override bool CanWrite => true;
+
+ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
+ {
+ if (value is null)
+ {
+ var token = JToken.FromObject("");
+ token.WriteTo(writer);
+ }
+ else if (value is not Activation)
+ {
+ throw new TypeError($"Unable to use `CustomizedActivationJsonConverter` to serialize the type {value.GetType()}.");
+ }
+ else
+ {
+ var token = JToken.FromObject((value as Activation)!.GetType().Name);
+ token.WriteTo(writer);
+ }
+ }
+
+ public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
+ {
+ throw new NotImplementedException();
+ //var dims = serializer.Deserialize(reader, typeof(string));
+ //if (dims is null)
+ //{
+ // throw new ValueError("Cannot deserialize 'null' to `Activation`.");
+ //}
+ //return new Shape((long[])(dims!));
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
new file mode 100644
index 00000000..4e190605
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
@@ -0,0 +1,48 @@
+using Newtonsoft.Json.Linq;
+using Newtonsoft.Json;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.Common
+{
+ public class CustomizedAxisJsonConverter : JsonConverter
+ {
+ public override bool CanConvert(Type objectType)
+ {
+ return objectType == typeof(Axis);
+ }
+
+ public override bool CanRead => true;
+
+ public override bool CanWrite => true;
+
+ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
+ {
+ if (value is null)
+ {
+ var token = JToken.FromObject(new int[] { });
+ token.WriteTo(writer);
+ }
+ else if (value is not Axis)
+ {
+ throw new TypeError($"Unable to use `CustomizedAxisJsonConverter` to serialize the type {value.GetType()}.");
+ }
+ else
+ {
+ var token = JToken.FromObject((value as Axis)!.axis);
+ token.WriteTo(writer);
+ }
+ }
+
+ public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
+ {
+ var axis = serializer.Deserialize(reader, typeof(long[]));
+ if (axis is null)
+ {
+ throw new ValueError("Cannot deserialize 'null' to `Axis`.");
+ }
+ return new Axis((int[])(axis!));
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs
new file mode 100644
index 00000000..1ad19fc8
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs
@@ -0,0 +1,73 @@
+using Newtonsoft.Json;
+using Newtonsoft.Json.Converters;
+using Newtonsoft.Json.Linq;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow.Keras.Saving;
+
+namespace Tensorflow.Keras.Common
+{
+ public class CustomizedNodeConfigJsonConverter : JsonConverter
+ {
+ public override bool CanConvert(Type objectType)
+ {
+ return objectType == typeof(NodeConfig);
+ }
+
+ public override bool CanRead => true;
+
+ public override bool CanWrite => true;
+
+ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
+ {
+ if (value is null)
+ {
+ var token = JToken.FromObject(null);
+ token.WriteTo(writer);
+ }
+ else if (value is not NodeConfig)
+ {
+ throw new TypeError($"Unable to use `CustomizedNodeConfigJsonConverter` to serialize the type {value.GetType()}.");
+ }
+ else
+ {
+ var config = value as NodeConfig;
+ var token = JToken.FromObject(new object[] { config!.Name, config.NodeIndex, config.TensorIndex });
+ token.WriteTo(writer);
+ }
+ }
+
+ public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
+ {
+ var values = serializer.Deserialize(reader, typeof(object[])) as object[];
+ if (values is null)
+ {
+ throw new ValueError("Cannot deserialize 'null' to `Shape`.");
+ }
+ if(values.Length != 3)
+ {
+ throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
+ }
+ if (values[0] is not string)
+ {
+ throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`");
+ }
+ if (values[1] is not int)
+ {
+ throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`");
+ }
+ if (values[2] is not int)
+ {
+ throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`");
+ }
+ return new NodeConfig()
+ {
+ Name = values[0] as string,
+ NodeIndex = (int)values[1],
+ TensorIndex = (int)values[2]
+ };
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs
new file mode 100644
index 00000000..300cb2f2
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs
@@ -0,0 +1,67 @@
+using Newtonsoft.Json;
+using Newtonsoft.Json.Converters;
+using Newtonsoft.Json.Linq;
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.Common
+{
+ public class CustomizedShapeJsonConverter: JsonConverter
+ {
+ public override bool CanConvert(Type objectType)
+ {
+ return objectType == typeof(Shape);
+ }
+
+ public override bool CanRead => true;
+
+ public override bool CanWrite => true;
+
+ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
+ {
+ if(value is null)
+ {
+ var token = JToken.FromObject(null);
+ token.WriteTo(writer);
+ }
+ else if(value is not Shape)
+ {
+ throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}.");
+ }
+ else
+ {
+ var shape = (value as Shape)!;
+ long?[] dims = new long?[shape.ndim];
+ for(int i = 0; i < dims.Length; i++)
+ {
+ if (shape.dims[i] == -1)
+ {
+ dims[i] = null;
+ }
+ else
+ {
+ dims[i] = shape.dims[i];
+ }
+ }
+ var token = JToken.FromObject(dims);
+ token.WriteTo(writer);
+ }
+ }
+
+ public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
+ {
+ var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
+ if(dims is null)
+ {
+ throw new ValueError("Cannot deserialize 'null' to `Shape`.");
+ }
+ long[] convertedDims = new long[dims.Length];
+ for(int i = 0; i < dims.Length; i++)
+ {
+ convertedDims[i] = dims[i] ?? (-1);
+ }
+ return new Shape(convertedDims);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
index 7280594b..6743935c 100644
--- a/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
+++ b/src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
@@ -16,23 +16,27 @@
using System.Collections.Generic;
using System.Linq;
+using Tensorflow.Keras.Saving;
namespace Tensorflow.Keras.Engine
{
///
/// Specifies the ndim, dtype and shape of every input to a layer.
///
- public class InputSpec
+ public class InputSpec: IKerasConfigable
{
public int? ndim;
+ public int? max_ndim;
public int? min_ndim;
Dictionary axes;
Shape shape;
+ TF_DataType dtype;
public int[] AllAxisDim;
public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
int? ndim = null,
int? min_ndim = null,
+ int? max_ndim = null,
Dictionary axes = null,
Shape shape = null)
{
@@ -41,7 +45,9 @@ namespace Tensorflow.Keras.Engine
axes = new Dictionary();
this.axes = axes;
this.min_ndim = min_ndim;
+ this.max_ndim = max_ndim;
this.shape = shape;
+ this.dtype = dtype;
if (ndim == null && shape != null)
this.ndim = shape.ndim;
@@ -49,7 +55,30 @@ namespace Tensorflow.Keras.Engine
AllAxisDim = axes.Select(x => x.Value).ToArray();
}
+ public IKerasConfig get_config()
+ {
+ return new Config()
+ {
+ DType = dtype == TF_DataType.DtInvalid ? null : dtype,
+ Shape = shape,
+ Ndim = ndim,
+ MinNdim = min_ndim,
+ MaxNdim = max_ndim,
+ Axes = axes.ToDictionary(x => x.Key.ToString(), x => x.Value)
+ };
+ }
+
public override string ToString()
=> $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}";
+
+ public class Config: IKerasConfig
+ {
+ public TF_DataType? DType { get; set; }
+ public Shape Shape { get; set; }
+ public int? Ndim { get; set; }
+ public int? MinNdim { get;set; }
+ public int? MaxNdim { get;set; }
+ public IDictionary Axes { get; set; }
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
index f1ca5632..ebf3358d 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
@@ -1,11 +1,12 @@
using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Saving;
using Tensorflow.Training;
namespace Tensorflow.Keras
{
- public interface ILayer: ITrackable
+ public interface ILayer: IWithTrackable, IKerasConfigable
{
string Name { get; }
bool Trainable { get; }
@@ -19,8 +20,8 @@ namespace Tensorflow.Keras
List NonTrainableWeights { get; }
Shape OutputShape { get; }
Shape BatchInputShape { get; }
+ TensorShapeConfig BuildInputShape { get; }
TF_DataType DType { get; }
int count_params();
- LayerArgs get_config();
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs
new file mode 100644
index 00000000..1217e1e5
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs
@@ -0,0 +1,15 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Keras.Saving
+{
+ public interface IKerasConfig
+ {
+ }
+
+ public interface IKerasConfigable
+ {
+ IKerasConfig get_config();
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs
index b8b8cab4..4ce290c8 100644
--- a/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs
+++ b/src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs
@@ -1,4 +1,5 @@
-using System;
+using Newtonsoft.Json;
+using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.ArgsDefinition;
@@ -6,11 +7,15 @@ using Tensorflow.Keras.Engine;
namespace Tensorflow.Keras.Saving
{
- public class LayerConfig
+ public class LayerConfig: IKerasConfig
{
+ [JsonProperty("name")]
public string Name { get; set; }
+ [JsonProperty("class_name")]
public string ClassName { get; set; }
+ [JsonProperty("config")]
public LayerArgs Config { get; set; }
+ [JsonProperty("inbound_nodes")]
public List InboundNodes { get; set; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
index abfb235b..cac19180 100644
--- a/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
+++ b/src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
@@ -1,15 +1,20 @@
-using System;
+using Newtonsoft.Json;
+using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Engine;
namespace Tensorflow.Keras.Saving
{
- public class ModelConfig
+ public class ModelConfig : IKerasConfig
{
+ [JsonProperty("name")]
public string Name { get; set; }
+ [JsonProperty("layers")]
public List Layers { get; set; }
+ [JsonProperty("input_layers")]
public List InputLayers { get; set; }
+ [JsonProperty("output_layers")]
public List OutputLayers { get; set; }
public override string ToString()
diff --git a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs
index 3132248e..20e2fef5 100644
--- a/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs
+++ b/src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs
@@ -1,10 +1,13 @@
-using System;
+using Newtonsoft.Json;
+using System;
using System.Collections.Generic;
using System.Text;
+using Tensorflow.Keras.Common;
namespace Tensorflow.Keras.Saving
{
- public class NodeConfig
+ [JsonConverter(typeof(CustomizedNodeConfigJsonConverter))]
+ public class NodeConfig : IKerasConfig
{
public string Name { get; set; }
public int NodeIndex { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs b/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs
new file mode 100644
index 00000000..7abcfde2
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs
@@ -0,0 +1,21 @@
+using Newtonsoft.Json;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Tensorflow.Keras.Saving
+{
+ public class TensorShapeConfig
+ {
+ [JsonProperty("class_name")]
+ public string ClassName { get; set; } = "TensorShape";
+ [JsonProperty("items")]
+ public long?[] Items { get; set; }
+
+ public static implicit operator Shape(TensorShapeConfig shape)
+ => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray());
+
+ public static implicit operator TensorShapeConfig(Shape shape)
+ => new TensorShapeConfig() { Items = shape.dims.Select(x => x == -1 ? null : x).ToArray() };
+ }
+}
diff --git a/src/TensorFlowNET.Core/NumPy/Axis.cs b/src/TensorFlowNET.Core/NumPy/Axis.cs
index 6c7189df..709ca9b2 100644
--- a/src/TensorFlowNET.Core/NumPy/Axis.cs
+++ b/src/TensorFlowNET.Core/NumPy/Axis.cs
@@ -14,20 +14,29 @@
limitations under the License.
******************************************************************************/
+using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
+using Tensorflow.Keras.Common;
namespace Tensorflow
{
- public record Axis(params int[] axis)
+ [JsonConverter(typeof(CustomizedAxisJsonConverter))]
+ public class Axis
{
+ public int[] axis { get; set; }
public int size => axis == null ? -1 : axis.Length;
public bool IsScalar { get; init; }
public int this[int index] => axis[index];
+ public Axis(params int[] axis)
+ {
+ this.axis = axis;
+ }
+
public static implicit operator int[]?(Axis axis)
=> axis?.axis;
diff --git a/src/TensorFlowNET.Core/Numpy/Shape.cs b/src/TensorFlowNET.Core/Numpy/Shape.cs
index bc79fefc..ecf73586 100644
--- a/src/TensorFlowNET.Core/Numpy/Shape.cs
+++ b/src/TensorFlowNET.Core/Numpy/Shape.cs
@@ -14,14 +14,17 @@
limitations under the License.
******************************************************************************/
+using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
+using Tensorflow.Keras.Common;
using Tensorflow.NumPy;
namespace Tensorflow
{
+ [JsonConverter(typeof(CustomizedShapeJsonConverter))]
public class Shape
{
public int ndim => _dims == null ? -1 : _dims.Length;
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
index fdcb5aff..e7e9955c 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/
+using System.Collections.Generic;
+
namespace Tensorflow.Operations.Initializers
{
public class Constant : IInitializer
@@ -22,11 +24,19 @@ namespace Tensorflow.Operations.Initializers
T value;
bool _verify_shape;
+ private readonly Dictionary _config;
+
+ public string ClassName => "Constant";
+ public IDictionary Config => _config;
+
public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false)
{
this.value = value;
this.dtype = dtype;
_verify_shape = verify_shape;
+
+ _config = new Dictionary();
+ _config["value"] = this.value;
}
public Tensor Apply(InitializerArgs args)
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
index d97d8830..def1cb7a 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
@@ -14,10 +14,17 @@
limitations under the License.
******************************************************************************/
+using System.Collections.Generic;
+
namespace Tensorflow.Operations.Initializers
{
public class GlorotUniform : VarianceScaling
{
+ private readonly Dictionary _config;
+
+ public override string ClassName => "GlorotUniform";
+ public override IDictionary Config => _config;
+
public GlorotUniform(float scale = 1.0f,
string mode = "FAN_AVG",
bool uniform = true,
@@ -28,7 +35,8 @@ namespace Tensorflow.Operations.Initializers
seed: seed,
dtype: dtype)
{
-
+ _config = new Dictionary();
+ _config["seed"] = _seed;
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
index 50d4d503..9748b100 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
@@ -14,10 +14,17 @@
limitations under the License.
******************************************************************************/
+using Newtonsoft.Json;
+using System.Collections.Generic;
+
namespace Tensorflow
{
public interface IInitializer
{
+ [JsonProperty("class_name")]
+ string ClassName { get; }
+ [JsonProperty("config")]
+ IDictionary Config { get; }
Tensor Apply(InitializerArgs args);
}
}
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
index 02d3c93b..3077a1e0 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
@@ -14,12 +14,19 @@
limitations under the License.
******************************************************************************/
+using System.Collections.Generic;
+
namespace Tensorflow.Operations.Initializers
{
public class Ones : IInitializer
{
private TF_DataType dtype;
+ private readonly Dictionary _config;
+
+ public string ClassName => "Ones";
+ public IDictionary Config => new Dictionary();
+
public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT)
{
this.dtype = dtype;
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
index 254a7ee7..cdc1c3ed 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
@@ -1,9 +1,14 @@
using System;
+using System.Collections.Generic;
namespace Tensorflow.Operations.Initializers
{
public class Orthogonal : IInitializer
{
+ private readonly Dictionary _config;
+
+ public string ClassName => "Orthogonal";
+ public IDictionary Config => throw new NotImplementedException();
public Tensor Apply(InitializerArgs args)
{
throw new NotImplementedException();
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
index 029b311b..21fa7e2b 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/
+using System.Collections.Generic;
+
namespace Tensorflow.Operations.Initializers
{
public class RandomNormal : IInitializer
@@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers
private int? seed;
private TF_DataType dtype;
+ private readonly Dictionary _config;
+
+ public string ClassName => "RandomNormal";
+ public IDictionary Config => _config;
+
public RandomNormal(float mean = 0.0f,
float stddev = 0.05f,
int? seed = null,
@@ -32,6 +39,11 @@ namespace Tensorflow.Operations.Initializers
this.stddev = stddev;
this.seed = seed;
this.dtype = dtype;
+
+ _config = new Dictionary();
+ _config["mean"] = this.mean;
+ _config["stddev"] = this.stddev;
+ _config["seed"] = this.seed;
}
public Tensor Apply(InitializerArgs args)
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs
index a49d5921..87404708 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs
@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/
+using System.Collections.Generic;
+
namespace Tensorflow.Operations.Initializers
{
public class RandomUniform : IInitializer
@@ -23,12 +25,22 @@ namespace Tensorflow.Operations.Initializers
private float maxval;
private TF_DataType dtype;
+ private readonly Dictionary _config;
+
+ public string ClassName => "RandomUniform";
+ public IDictionary Config => _config;
+
public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null)
{
this.dtype = dtype;
this.minval = minval;
this.maxval = maxval;
this.seed = seed;
+
+ _config = new Dictionary();
+ _config["minval"] = this.minval;
+ _config["maxval"] = this.maxval;
+ _config["seed"] = this.seed;
}
public Tensor Apply(InitializerArgs args)
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
index 048c11e7..c1c3e999 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/
+using System.Collections.Generic;
+
namespace Tensorflow.Operations.Initializers
{
public class TruncatedNormal : IInitializer
@@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers
private int? seed;
private TF_DataType dtype;
+ private readonly Dictionary _config;
+
+ public string ClassName => "TruncatedNormal";
+ public IDictionary Config => _config;
+
public TruncatedNormal(float mean = 0.0f,
float stddev = 1.0f,
int? seed = null,
@@ -32,6 +39,10 @@ namespace Tensorflow.Operations.Initializers
this.stddev = stddev;
this.seed = seed;
this.dtype = dtype;
+ _config = new Dictionary();
+ _config["mean"] = this.mean;
+ _config["stddev"] = this.stddev;
+ _config["seed"] = this.seed;
}
public Tensor Apply(InitializerArgs args)
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
index d313f4c9..f104e8e8 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
@@ -15,7 +15,9 @@
******************************************************************************/
using System;
+using System.Collections.Generic;
using System.Linq;
+using System.Linq.Expressions;
namespace Tensorflow.Operations.Initializers
{
@@ -30,6 +32,11 @@ namespace Tensorflow.Operations.Initializers
protected int? _seed;
protected TF_DataType _dtype;
protected bool _uniform;
+ private readonly Dictionary _config;
+
+ public virtual string ClassName => "VarianceScaling";
+
+ public virtual IDictionary Config => _config;
public VarianceScaling(float factor = 2.0f,
string mode = "FAN_IN",
@@ -50,6 +57,12 @@ namespace Tensorflow.Operations.Initializers
_seed = seed;
_dtype = dtype;
_uniform = uniform;
+
+ _config = new();
+ _config["scale"] = _scale;
+ _config["mode"] = _mode;
+ _config["distribution"] = _distribution;
+ _config["seed"] = _seed;
}
public Tensor Apply(InitializerArgs args)
diff --git a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
index 5d045292..c4ed25a1 100644
--- a/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
+++ b/src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
@@ -14,6 +14,8 @@
limitations under the License.
******************************************************************************/
+using System.Collections.Generic;
+
namespace Tensorflow.Operations.Initializers
{
public class Zeros : IInitializer
@@ -21,6 +23,9 @@ namespace Tensorflow.Operations.Initializers
Shape shape;
TF_DataType dtype;
+ public string ClassName => "Zeros";
+ public IDictionary Config => new Dictionary();
+
public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT)
{
this.shape = shape;
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
index 734f2608..c29ed47b 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
@@ -20,6 +20,7 @@ using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.Saving;
using Tensorflow.Operations;
using Tensorflow.Train;
using Tensorflow.Util;
@@ -76,6 +77,8 @@ namespace Tensorflow
public Shape BatchInputShape => throw new NotImplementedException();
+ public TensorShapeConfig BuildInputShape => throw new NotImplementedException();
+
public TF_DataType DType => throw new NotImplementedException();
protected bool built = false;
public bool Built => built;
@@ -144,7 +147,7 @@ namespace Tensorflow
throw new NotImplementedException();
}
- public LayerArgs get_config()
+ public IKerasConfig get_config()
{
throw new NotImplementedException();
}
diff --git a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
index 0ebe61d0..7068ed47 100644
--- a/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
+++ b/src/TensorFlowNET.Core/Tensorflow.Binding.csproj
@@ -108,6 +108,7 @@ https://tensorflownet.readthedocs.io
+
diff --git a/src/TensorFlowNET.Core/Tensors/dtypes.cs b/src/TensorFlowNET.Core/Tensors/dtypes.cs
index 372ac676..deeb9e4b 100644
--- a/src/TensorFlowNET.Core/Tensors/dtypes.cs
+++ b/src/TensorFlowNET.Core/Tensors/dtypes.cs
@@ -202,6 +202,24 @@ namespace Tensorflow
_ => type.ToString()
};
+ public static string as_python_name(this TF_DataType type)
+ => type switch
+ {
+ TF_DataType.TF_STRING => "str",
+ TF_DataType.TF_UINT8 => "uint8",
+ TF_DataType.TF_INT8 => "int8",
+ TF_DataType.TF_UINT32 => "uint32",
+ TF_DataType.TF_INT32 => "int32",
+ TF_DataType.TF_UINT64 => "uint64",
+ TF_DataType.TF_INT64 => "int64",
+ TF_DataType.TF_FLOAT => "float32",
+ TF_DataType.TF_DOUBLE => "float64",
+ TF_DataType.TF_BOOL => "bool",
+ TF_DataType.TF_RESOURCE => "resource",
+ TF_DataType.TF_VARIANT => "variant",
+ _ => type.ToString()
+ };
+
public static int get_datatype_size(this TF_DataType type)
=> type.as_base_dtype() switch
{
diff --git a/src/TensorFlowNET.Core/Training/ITrackable.cs b/src/TensorFlowNET.Core/Training/IWithTrackable.cs
similarity index 82%
rename from src/TensorFlowNET.Core/Training/ITrackable.cs
rename to src/TensorFlowNET.Core/Training/IWithTrackable.cs
index e4ef2c8f..87eda879 100644
--- a/src/TensorFlowNET.Core/Training/ITrackable.cs
+++ b/src/TensorFlowNET.Core/Training/IWithTrackable.cs
@@ -5,7 +5,7 @@ using Tensorflow.Train;
namespace Tensorflow.Training
{
- public interface ITrackable
+ public interface IWithTrackable
{
Trackable GetTrackable();
}
diff --git a/src/TensorFlowNET.Core/Training/Trackable.cs b/src/TensorFlowNET.Core/Training/Trackable.cs
index 434d51b6..132571f2 100644
--- a/src/TensorFlowNET.Core/Training/Trackable.cs
+++ b/src/TensorFlowNET.Core/Training/Trackable.cs
@@ -26,7 +26,7 @@ using static Tensorflow.Binding;
namespace Tensorflow.Train
{
- public abstract class Trackable: ITrackable
+ public abstract class Trackable: IWithTrackable
{
///
/// Corresponding to tensorflow/python/trackable/constants.py
diff --git a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs
index a221444b..3aeb3200 100644
--- a/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs
+++ b/src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs
@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine
{
public partial class Functional
{
- public ModelConfig get_config()
+ public override IKerasConfig get_config()
{
return get_network_config();
}
@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine
{
Name = name
};
-
+
var node_conversion_map = new Dictionary();
foreach (var layer in _self_tracked_trackables)
{
@@ -42,23 +42,26 @@ namespace Tensorflow.Keras.Engine
}
var layer_configs = new List();
- foreach (var layer in _self_tracked_trackables)
+ using (SharedObjectSavingScope.Enter())
{
- var filtered_inbound_nodes = new List();
- foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
+ foreach (var layer in _self_tracked_trackables)
{
- var node_key = _make_node_key(layer.Name, original_node_index);
- if (NetworkNodes.Contains(node_key) && !node.is_input)
+ var filtered_inbound_nodes = new List();
+ foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
{
- var node_data = node.serialize(_make_node_key, node_conversion_map);
- filtered_inbound_nodes.append(node_data);
+ var node_key = _make_node_key(layer.Name, original_node_index);
+ if (NetworkNodes.Contains(node_key) && !node.is_input)
+ {
+ var node_data = node.serialize(_make_node_key, node_conversion_map);
+ filtered_inbound_nodes.append(node_data);
+ }
}
- }
- var layer_config = generic_utils.serialize_layer_to_config(layer);
- layer_config.Name = layer.Name;
- layer_config.InboundNodes = filtered_inbound_nodes;
- layer_configs.Add(layer_config);
+ var layer_config = generic_utils.serialize_layer_to_config(layer);
+ layer_config.Name = layer.Name;
+ layer_config.InboundNodes = filtered_inbound_nodes;
+ layer_configs.Add(layer_config);
+ }
}
config.Layers = layer_configs;
diff --git a/src/TensorFlowNET.Keras/Engine/Functional.cs b/src/TensorFlowNET.Keras/Engine/Functional.cs
index 7c8812ad..44eaef53 100644
--- a/src/TensorFlowNET.Keras/Engine/Functional.cs
+++ b/src/TensorFlowNET.Keras/Engine/Functional.cs
@@ -70,6 +70,7 @@ namespace Tensorflow.Keras.Engine
this.inputs = inputs;
this.outputs = outputs;
built = true;
+ _buildInputShape = inputs.shape;
if (outputs.Any(x => x.KerasHistory == null))
base_layer_utils.create_keras_history(outputs);
@@ -357,5 +358,22 @@ namespace Tensorflow.Keras.Engine
return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache))
.ToDictionary(x => x.Key, x => x.Value);
}
+
+ protected override void _init_set_name(string name, bool zero_based = true)
+ {
+ if (string.IsNullOrEmpty(name))
+ {
+ string class_name = GetType().Name;
+ if (this.GetType() == typeof(Functional))
+ {
+ class_name = "Model";
+ }
+ this.name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(class_name), zero_based: zero_based);
+ }
+ else
+ {
+ this.name = name;
+ }
+ }
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs
index a2f92ba8..31b37d68 100644
--- a/src/TensorFlowNET.Keras/Engine/Layer.cs
+++ b/src/TensorFlowNET.Keras/Engine/Layer.cs
@@ -61,6 +61,7 @@ namespace Tensorflow.Keras.Engine
/// Provides information about which inputs are compatible with the layer.
///
protected InputSpec inputSpec;
+ public InputSpec InputSpec => inputSpec;
bool dynamic = true;
public bool SupportsMasking { get; set; }
protected List _trainable_weights;
@@ -79,6 +80,8 @@ namespace Tensorflow.Keras.Engine
protected bool computePreviousMask;
protected List updates;
public Shape BatchInputShape => args.BatchInputShape;
+ protected TensorShapeConfig _buildInputShape = null;
+ public TensorShapeConfig BuildInputShape => _buildInputShape;
List inboundNodes;
public List InboundNodes => inboundNodes;
@@ -223,6 +226,7 @@ namespace Tensorflow.Keras.Engine
public virtual void build(Shape input_shape)
{
+ _buildInputShape = input_shape;
built = true;
}
@@ -310,7 +314,7 @@ namespace Tensorflow.Keras.Engine
public List Variables => weights;
- public virtual LayerArgs get_config()
+ public virtual IKerasConfig get_config()
=> args;
}
}
diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs
index 59b205e4..85da920e 100644
--- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs
+++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs
@@ -1,6 +1,7 @@
using System.Collections.Generic;
using Tensorflow.Functions;
using Tensorflow.Keras.Metrics;
+using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.ModelSaving;
@@ -30,7 +31,10 @@ namespace Tensorflow.Keras.Engine
}
else
{
- KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces);
+ using (SharedObjectSavingScope.Enter())
+ {
+ KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces);
+ }
}
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
index 6e790a26..45f64720 100644
--- a/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
+++ b/src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
@@ -25,6 +25,7 @@ namespace Tensorflow.Keras.Layers {
{
throw new ValueError("Alpha must be a number greater than 0.");
}
+ _buildInputShape = input_shape;
built = true;
}
diff --git a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
index aba175de..2fd2caee 100644
--- a/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
+++ b/src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
@@ -14,6 +14,7 @@ namespace Tensorflow.Keras.Layers {
}
public override void build(Shape input_shape)
{
+ _buildInputShape = input_shape;
built = true;
}
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
diff --git a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs
index b12d7dee..1ef8d0e5 100644
--- a/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs
+++ b/src/TensorFlowNET.Keras/Layers/Activation/SELU.cs
@@ -16,10 +16,11 @@ namespace Tensorflow.Keras.Layers {
// SELU has no arguments
}
public override void build(Shape input_shape) {
- if ( alpha < 0f ) {
- throw new ValueError("Alpha must be a number greater than 0.");
- }
- built = true;
+ if ( alpha < 0f ) {
+ throw new ValueError("Alpha must be a number greater than 0.");
+ }
+ _buildInputShape = input_shape;
+ built = true;
}
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
Tensor output = inputs;
diff --git a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs
index 6f6dd7e8..c5131630 100644
--- a/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs
+++ b/src/TensorFlowNET.Keras/Layers/Attention/Attention.cs
@@ -4,6 +4,7 @@ using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Saving;
namespace Tensorflow.Keras.Layers
{
@@ -146,7 +147,7 @@ namespace Tensorflow.Keras.Layers
return scores;
}
- public override LayerArgs get_config() => this.args;
+ public override IKerasConfig get_config() => this.args;
//var config = new Dictionary