| @@ -1,9 +1,18 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.ArgsDefinition { | namespace Tensorflow.Keras.ArgsDefinition { | ||||
| public class SoftmaxArgs : LayerArgs { | |||||
| public Axis axis { get; set; } = -1; | |||||
| } | |||||
| public class SoftmaxArgs : LayerArgs | |||||
| { | |||||
| [JsonProperty("axis")] | |||||
| public Axis axis { get; set; } = -1; | |||||
| [JsonProperty("name")] | |||||
| public override string Name { get => base.Name; set => base.Name = value; } | |||||
| [JsonProperty("trainable")] | |||||
| public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } | |||||
| [JsonProperty("dtype")] | |||||
| public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||||
| } | |||||
| } | } | ||||
| @@ -0,0 +1,19 @@ | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | |||||
| public class AutoSerializeLayerArgs: LayerArgs | |||||
| { | |||||
| [JsonProperty("name")] | |||||
| public override string Name { get => base.Name; set => base.Name = value; } | |||||
| [JsonProperty("dtype")] | |||||
| public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||||
| [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] | |||||
| public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||||
| [JsonProperty("trainable")] | |||||
| public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } | |||||
| } | |||||
| } | |||||
| @@ -1,13 +1,18 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Xml.Linq; | |||||
| using Tensorflow.Operations.Initializers; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| // TODO: `activity_regularizer` | |||||
| public class DenseArgs : LayerArgs | public class DenseArgs : LayerArgs | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Positive integer, dimensionality of the output space. | /// Positive integer, dimensionality of the output space. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("units")] | |||||
| public int Units { get; set; } | public int Units { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -15,39 +20,74 @@ namespace Tensorflow.Keras.ArgsDefinition | |||||
| /// </summary> | /// </summary> | ||||
| public Activation Activation { get; set; } | public Activation Activation { get; set; } | ||||
| private string _activationName; | |||||
| [JsonProperty("activation")] | |||||
| public string ActivationName | |||||
| { | |||||
| get | |||||
| { | |||||
| if (string.IsNullOrEmpty(_activationName)) | |||||
| { | |||||
| return Activation.Method.Name; | |||||
| } | |||||
| else | |||||
| { | |||||
| return _activationName; | |||||
| } | |||||
| } | |||||
| set | |||||
| { | |||||
| _activationName = value; | |||||
| } | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Whether the layer uses a bias vector. | /// Whether the layer uses a bias vector. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("use_bias")] | |||||
| public bool UseBias { get; set; } = true; | public bool UseBias { get; set; } = true; | ||||
| /// <summary> | /// <summary> | ||||
| /// Initializer for the `kernel` weights matrix. | /// Initializer for the `kernel` weights matrix. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("kernel_initializer")] | |||||
| public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; | ||||
| /// <summary> | /// <summary> | ||||
| /// Initializer for the bias vector. | /// Initializer for the bias vector. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("bias_initializer")] | |||||
| public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; | ||||
| /// <summary> | /// <summary> | ||||
| /// Regularizer function applied to the `kernel` weights matrix. | /// Regularizer function applied to the `kernel` weights matrix. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("kernel_regularizer")] | |||||
| public IRegularizer KernelRegularizer { get; set; } | public IRegularizer KernelRegularizer { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// Regularizer function applied to the bias vector. | /// Regularizer function applied to the bias vector. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("bias_regularizer")] | |||||
| public IRegularizer BiasRegularizer { get; set; } | public IRegularizer BiasRegularizer { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// Constraint function applied to the `kernel` weights matrix. | /// Constraint function applied to the `kernel` weights matrix. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("kernel_constraint")] | |||||
| public Action KernelConstraint { get; set; } | public Action KernelConstraint { get; set; } | ||||
| /// <summary> | /// <summary> | ||||
| /// Constraint function applied to the bias vector. | /// Constraint function applied to the bias vector. | ||||
| /// </summary> | /// </summary> | ||||
| [JsonProperty("bias_constraint")] | |||||
| public Action BiasConstraint { get; set; } | public Action BiasConstraint { get; set; } | ||||
| [JsonProperty("name")] | |||||
| public override string Name { get => base.Name; set => base.Name = value; } | |||||
| [JsonProperty("dtype")] | |||||
| public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||||
| [JsonProperty("trainable")] | |||||
| public override bool Trainable { get => base.Trainable; set => base.Trainable = value; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,9 +1,22 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Serialization; | |||||
| using Tensorflow.Keras.Common; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class InputLayerArgs : LayerArgs | public class InputLayerArgs : LayerArgs | ||||
| { | { | ||||
| [JsonIgnore] | |||||
| public Tensor InputTensor { get; set; } | public Tensor InputTensor { get; set; } | ||||
| public bool Sparse { get; set; } | |||||
| [JsonProperty("sparse")] | |||||
| public virtual bool Sparse { get; set; } | |||||
| [JsonProperty("ragged")] | |||||
| public bool Ragged { get; set; } | public bool Ragged { get; set; } | ||||
| [JsonProperty("name")] | |||||
| public override string Name { get => base.Name; set => base.Name = value; } | |||||
| [JsonProperty("dtype")] | |||||
| public override TF_DataType DType { get => base.DType; set => base.DType = value; } | |||||
| [JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)] | |||||
| public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,8 +1,9 @@ | |||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class DataAdapterArgs | |||||
| public class DataAdapterArgs: IKerasConfig | |||||
| { | { | ||||
| public Tensor X { get; set; } | public Tensor X { get; set; } | ||||
| public Tensor Y { get; set; } | public Tensor Y { get; set; } | ||||
| @@ -1,8 +1,9 @@ | |||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | namespace Tensorflow.Keras.ArgsDefinition | ||||
| { | { | ||||
| public class DataHandlerArgs | |||||
| public class DataHandlerArgs: IKerasConfig | |||||
| { | { | ||||
| public Tensor X { get; set; } | public Tensor X { get; set; } | ||||
| public Tensor Y { get; set; } | public Tensor Y { get; set; } | ||||
| @@ -1,51 +1,54 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Newtonsoft.Json; | |||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class LayerArgs | |||||
| [JsonObject(MemberSerialization.OptIn)] | |||||
| public class LayerArgs: IKerasConfig | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Indicates whether the layer's weights are updated during training | /// Indicates whether the layer's weights are updated during training | ||||
| /// and whether the layer's updates are run during training. | /// and whether the layer's updates are run during training. | ||||
| /// </summary> | /// </summary> | ||||
| public bool Trainable { get; set; } = true; | |||||
| public string Name { get; set; } | |||||
| public virtual bool Trainable { get; set; } = true; | |||||
| public virtual string Name { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Only applicable to input layers. | /// Only applicable to input layers. | ||||
| /// </summary> | /// </summary> | ||||
| public TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; | |||||
| public virtual TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT; | |||||
| /// <summary> | /// <summary> | ||||
| /// Whether the `call` method can be used to build a TF graph without issues. | /// Whether the `call` method can be used to build a TF graph without issues. | ||||
| /// This attribute has no effect if the model is created using the Functional | /// This attribute has no effect if the model is created using the Functional | ||||
| /// API. Instead, `model.dynamic` is determined based on the internal layers. | /// API. Instead, `model.dynamic` is determined based on the internal layers. | ||||
| /// </summary> | /// </summary> | ||||
| public bool Dynamic { get; set; } = false; | |||||
| public virtual bool Dynamic { get; set; } = false; | |||||
| /// <summary> | /// <summary> | ||||
| /// Only applicable to input layers. | /// Only applicable to input layers. | ||||
| /// </summary> | /// </summary> | ||||
| public Shape InputShape { get; set; } | |||||
| public virtual Shape InputShape { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Only applicable to input layers. | /// Only applicable to input layers. | ||||
| /// </summary> | /// </summary> | ||||
| public Shape BatchInputShape { get; set; } | |||||
| public virtual Shape BatchInputShape { get; set; } | |||||
| public int BatchSize { get; set; } = -1; | |||||
| public virtual int BatchSize { get; set; } = -1; | |||||
| /// <summary> | /// <summary> | ||||
| /// Initial weight values. | /// Initial weight values. | ||||
| /// </summary> | /// </summary> | ||||
| public float[] Weights { get; set; } | |||||
| public virtual float[] Weights { get; set; } | |||||
| /// <summary> | /// <summary> | ||||
| /// Regularizer function applied to the output of the layer(its "activation"). | /// Regularizer function applied to the output of the layer(its "activation"). | ||||
| /// </summary> | /// </summary> | ||||
| public IRegularizer ActivityRegularizer { get; set; } | |||||
| public virtual IRegularizer ActivityRegularizer { get; set; } | |||||
| public bool Autocast { get; set; } | |||||
| public virtual bool Autocast { get; set; } | |||||
| public bool IsFromConfig { get; set; } | |||||
| public virtual bool IsFromConfig { get; set; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,8 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class NodeArgs | |||||
| public class NodeArgs: IKerasConfig | |||||
| { | { | ||||
| public ILayer[] InboundLayers { get; set; } | public ILayer[] InboundLayers { get; set; } | ||||
| public int[] NodeIndices { get; set; } | public int[] NodeIndices { get; set; } | ||||
| @@ -1,6 +1,8 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class OptimizerV2Args | |||||
| public class OptimizerV2Args: IKerasConfig | |||||
| { | { | ||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| public float LearningRate { get; set; } = 0.001f; | public float LearningRate { get; set; } = 0.001f; | ||||
| @@ -1,7 +1,10 @@ | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| using Newtonsoft.Json; | |||||
| namespace Tensorflow.Keras.ArgsDefinition | |||||
| { | { | ||||
| public class FlattenArgs : LayerArgs | |||||
| public class FlattenArgs : AutoSerializeLayerArgs | |||||
| { | { | ||||
| [JsonProperty("data_format")] | |||||
| public string DataFormat { get; set; } | public string DataFormat { get; set; } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,50 @@ | |||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Converters; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Common | |||||
| { | |||||
| public class CustomizedActivationJsonConverter : JsonConverter | |||||
| { | |||||
| public override bool CanConvert(Type objectType) | |||||
| { | |||||
| return objectType == typeof(Activation); | |||||
| } | |||||
| public override bool CanRead => true; | |||||
| public override bool CanWrite => true; | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
| { | |||||
| if (value is null) | |||||
| { | |||||
| var token = JToken.FromObject(""); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| else if (value is not Activation) | |||||
| { | |||||
| throw new TypeError($"Unable to use `CustomizedActivationJsonConverter` to serialize the type {value.GetType()}."); | |||||
| } | |||||
| else | |||||
| { | |||||
| var token = JToken.FromObject((value as Activation)!.GetType().Name); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| } | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| //var dims = serializer.Deserialize(reader, typeof(string)); | |||||
| //if (dims is null) | |||||
| //{ | |||||
| // throw new ValueError("Cannot deserialize 'null' to `Activation`."); | |||||
| //} | |||||
| //return new Shape((long[])(dims!)); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,48 @@ | |||||
| using Newtonsoft.Json.Linq; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Common | |||||
| { | |||||
| public class CustomizedAxisJsonConverter : JsonConverter | |||||
| { | |||||
| public override bool CanConvert(Type objectType) | |||||
| { | |||||
| return objectType == typeof(Axis); | |||||
| } | |||||
| public override bool CanRead => true; | |||||
| public override bool CanWrite => true; | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
| { | |||||
| if (value is null) | |||||
| { | |||||
| var token = JToken.FromObject(new int[] { }); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| else if (value is not Axis) | |||||
| { | |||||
| throw new TypeError($"Unable to use `CustomizedAxisJsonConverter` to serialize the type {value.GetType()}."); | |||||
| } | |||||
| else | |||||
| { | |||||
| var token = JToken.FromObject((value as Axis)!.axis); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| } | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
| { | |||||
| var axis = serializer.Deserialize(reader, typeof(long[])); | |||||
| if (axis is null) | |||||
| { | |||||
| throw new ValueError("Cannot deserialize 'null' to `Axis`."); | |||||
| } | |||||
| return new Axis((int[])(axis!)); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,73 @@ | |||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Converters; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Common | |||||
| { | |||||
| public class CustomizedNodeConfigJsonConverter : JsonConverter | |||||
| { | |||||
| public override bool CanConvert(Type objectType) | |||||
| { | |||||
| return objectType == typeof(NodeConfig); | |||||
| } | |||||
| public override bool CanRead => true; | |||||
| public override bool CanWrite => true; | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
| { | |||||
| if (value is null) | |||||
| { | |||||
| var token = JToken.FromObject(null); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| else if (value is not NodeConfig) | |||||
| { | |||||
| throw new TypeError($"Unable to use `CustomizedNodeConfigJsonConverter` to serialize the type {value.GetType()}."); | |||||
| } | |||||
| else | |||||
| { | |||||
| var config = value as NodeConfig; | |||||
| var token = JToken.FromObject(new object[] { config!.Name, config.NodeIndex, config.TensorIndex }); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| } | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
| { | |||||
| var values = serializer.Deserialize(reader, typeof(object[])) as object[]; | |||||
| if (values is null) | |||||
| { | |||||
| throw new ValueError("Cannot deserialize 'null' to `Shape`."); | |||||
| } | |||||
| if(values.Length != 3) | |||||
| { | |||||
| throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`."); | |||||
| } | |||||
| if (values[0] is not string) | |||||
| { | |||||
| throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`"); | |||||
| } | |||||
| if (values[1] is not int) | |||||
| { | |||||
| throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`"); | |||||
| } | |||||
| if (values[2] is not int) | |||||
| { | |||||
| throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`"); | |||||
| } | |||||
| return new NodeConfig() | |||||
| { | |||||
| Name = values[0] as string, | |||||
| NodeIndex = (int)values[1], | |||||
| TensorIndex = (int)values[2] | |||||
| }; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,67 @@ | |||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Converters; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Common | |||||
| { | |||||
| public class CustomizedShapeJsonConverter: JsonConverter | |||||
| { | |||||
| public override bool CanConvert(Type objectType) | |||||
| { | |||||
| return objectType == typeof(Shape); | |||||
| } | |||||
| public override bool CanRead => true; | |||||
| public override bool CanWrite => true; | |||||
| public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer) | |||||
| { | |||||
| if(value is null) | |||||
| { | |||||
| var token = JToken.FromObject(null); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| else if(value is not Shape) | |||||
| { | |||||
| throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}."); | |||||
| } | |||||
| else | |||||
| { | |||||
| var shape = (value as Shape)!; | |||||
| long?[] dims = new long?[shape.ndim]; | |||||
| for(int i = 0; i < dims.Length; i++) | |||||
| { | |||||
| if (shape.dims[i] == -1) | |||||
| { | |||||
| dims[i] = null; | |||||
| } | |||||
| else | |||||
| { | |||||
| dims[i] = shape.dims[i]; | |||||
| } | |||||
| } | |||||
| var token = JToken.FromObject(dims); | |||||
| token.WriteTo(writer); | |||||
| } | |||||
| } | |||||
| public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer) | |||||
| { | |||||
| var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[]; | |||||
| if(dims is null) | |||||
| { | |||||
| throw new ValueError("Cannot deserialize 'null' to `Shape`."); | |||||
| } | |||||
| long[] convertedDims = new long[dims.Length]; | |||||
| for(int i = 0; i < dims.Length; i++) | |||||
| { | |||||
| convertedDims[i] = dims[i] ?? (-1); | |||||
| } | |||||
| return new Shape(convertedDims); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -16,23 +16,27 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Engine | namespace Tensorflow.Keras.Engine | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Specifies the ndim, dtype and shape of every input to a layer. | /// Specifies the ndim, dtype and shape of every input to a layer. | ||||
| /// </summary> | /// </summary> | ||||
| public class InputSpec | |||||
| public class InputSpec: IKerasConfigable | |||||
| { | { | ||||
| public int? ndim; | public int? ndim; | ||||
| public int? max_ndim; | |||||
| public int? min_ndim; | public int? min_ndim; | ||||
| Dictionary<int, int> axes; | Dictionary<int, int> axes; | ||||
| Shape shape; | Shape shape; | ||||
| TF_DataType dtype; | |||||
| public int[] AllAxisDim; | public int[] AllAxisDim; | ||||
| public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, | public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, | ||||
| int? ndim = null, | int? ndim = null, | ||||
| int? min_ndim = null, | int? min_ndim = null, | ||||
| int? max_ndim = null, | |||||
| Dictionary<int, int> axes = null, | Dictionary<int, int> axes = null, | ||||
| Shape shape = null) | Shape shape = null) | ||||
| { | { | ||||
| @@ -41,7 +45,9 @@ namespace Tensorflow.Keras.Engine | |||||
| axes = new Dictionary<int, int>(); | axes = new Dictionary<int, int>(); | ||||
| this.axes = axes; | this.axes = axes; | ||||
| this.min_ndim = min_ndim; | this.min_ndim = min_ndim; | ||||
| this.max_ndim = max_ndim; | |||||
| this.shape = shape; | this.shape = shape; | ||||
| this.dtype = dtype; | |||||
| if (ndim == null && shape != null) | if (ndim == null && shape != null) | ||||
| this.ndim = shape.ndim; | this.ndim = shape.ndim; | ||||
| @@ -49,7 +55,30 @@ namespace Tensorflow.Keras.Engine | |||||
| AllAxisDim = axes.Select(x => x.Value).ToArray(); | AllAxisDim = axes.Select(x => x.Value).ToArray(); | ||||
| } | } | ||||
| public IKerasConfig get_config() | |||||
| { | |||||
| return new Config() | |||||
| { | |||||
| DType = dtype == TF_DataType.DtInvalid ? null : dtype, | |||||
| Shape = shape, | |||||
| Ndim = ndim, | |||||
| MinNdim = min_ndim, | |||||
| MaxNdim = max_ndim, | |||||
| Axes = axes.ToDictionary(x => x.Key.ToString(), x => x.Value) | |||||
| }; | |||||
| } | |||||
| public override string ToString() | public override string ToString() | ||||
| => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; | => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; | ||||
| public class Config: IKerasConfig | |||||
| { | |||||
| public TF_DataType? DType { get; set; } | |||||
| public Shape Shape { get; set; } | |||||
| public int? Ndim { get; set; } | |||||
| public int? MinNdim { get;set; } | |||||
| public int? MaxNdim { get;set; } | |||||
| public IDictionary<string, int> Axes { get; set; } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,11 +1,12 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.Training; | using Tensorflow.Training; | ||||
| namespace Tensorflow.Keras | namespace Tensorflow.Keras | ||||
| { | { | ||||
| public interface ILayer: ITrackable | |||||
| public interface ILayer: IWithTrackable, IKerasConfigable | |||||
| { | { | ||||
| string Name { get; } | string Name { get; } | ||||
| bool Trainable { get; } | bool Trainable { get; } | ||||
| @@ -19,8 +20,8 @@ namespace Tensorflow.Keras | |||||
| List<IVariableV1> NonTrainableWeights { get; } | List<IVariableV1> NonTrainableWeights { get; } | ||||
| Shape OutputShape { get; } | Shape OutputShape { get; } | ||||
| Shape BatchInputShape { get; } | Shape BatchInputShape { get; } | ||||
| TensorShapeConfig BuildInputShape { get; } | |||||
| TF_DataType DType { get; } | TF_DataType DType { get; } | ||||
| int count_params(); | int count_params(); | ||||
| LayerArgs get_config(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,15 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Keras.Saving | |||||
| { | |||||
| public interface IKerasConfig | |||||
| { | |||||
| } | |||||
| public interface IKerasConfigable | |||||
| { | |||||
| IKerasConfig get_config(); | |||||
| } | |||||
| } | |||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| @@ -6,11 +7,15 @@ using Tensorflow.Keras.Engine; | |||||
| namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
| { | { | ||||
| public class LayerConfig | |||||
| public class LayerConfig: IKerasConfig | |||||
| { | { | ||||
| [JsonProperty("name")] | |||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| [JsonProperty("class_name")] | |||||
| public string ClassName { get; set; } | public string ClassName { get; set; } | ||||
| [JsonProperty("config")] | |||||
| public LayerArgs Config { get; set; } | public LayerArgs Config { get; set; } | ||||
| [JsonProperty("inbound_nodes")] | |||||
| public List<NodeConfig> InboundNodes { get; set; } | public List<NodeConfig> InboundNodes { get; set; } | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,15 +1,20 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
| { | { | ||||
| public class ModelConfig | |||||
| public class ModelConfig : IKerasConfig | |||||
| { | { | ||||
| [JsonProperty("name")] | |||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| [JsonProperty("layers")] | |||||
| public List<LayerConfig> Layers { get; set; } | public List<LayerConfig> Layers { get; set; } | ||||
| [JsonProperty("input_layers")] | |||||
| public List<NodeConfig> InputLayers { get; set; } | public List<NodeConfig> InputLayers { get; set; } | ||||
| [JsonProperty("output_layers")] | |||||
| public List<NodeConfig> OutputLayers { get; set; } | public List<NodeConfig> OutputLayers { get; set; } | ||||
| public override string ToString() | public override string ToString() | ||||
| @@ -1,10 +1,13 @@ | |||||
| using System; | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Common; | |||||
| namespace Tensorflow.Keras.Saving | namespace Tensorflow.Keras.Saving | ||||
| { | { | ||||
| public class NodeConfig | |||||
| [JsonConverter(typeof(CustomizedNodeConfigJsonConverter))] | |||||
| public class NodeConfig : IKerasConfig | |||||
| { | { | ||||
| public string Name { get; set; } | public string Name { get; set; } | ||||
| public int NodeIndex { get; set; } | public int NodeIndex { get; set; } | ||||
| @@ -0,0 +1,21 @@ | |||||
| using Newtonsoft.Json; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| namespace Tensorflow.Keras.Saving | |||||
| { | |||||
| public class TensorShapeConfig | |||||
| { | |||||
| [JsonProperty("class_name")] | |||||
| public string ClassName { get; set; } = "TensorShape"; | |||||
| [JsonProperty("items")] | |||||
| public long?[] Items { get; set; } | |||||
| public static implicit operator Shape(TensorShapeConfig shape) | |||||
| => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); | |||||
| public static implicit operator TensorShapeConfig(Shape shape) | |||||
| => new TensorShapeConfig() { Items = shape.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() }; | |||||
| } | |||||
| } | |||||
| @@ -14,20 +14,29 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Newtonsoft.Json; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Common; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public record Axis(params int[] axis) | |||||
| [JsonConverter(typeof(CustomizedAxisJsonConverter))] | |||||
| public class Axis | |||||
| { | { | ||||
| public int[] axis { get; set; } | |||||
| public int size => axis == null ? -1 : axis.Length; | public int size => axis == null ? -1 : axis.Length; | ||||
| public bool IsScalar { get; init; } | public bool IsScalar { get; init; } | ||||
| public int this[int index] => axis[index]; | public int this[int index] => axis[index]; | ||||
| public Axis(params int[] axis) | |||||
| { | |||||
| this.axis = axis; | |||||
| } | |||||
| public static implicit operator int[]?(Axis axis) | public static implicit operator int[]?(Axis axis) | ||||
| => axis?.axis; | => axis?.axis; | ||||
| @@ -14,14 +14,17 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Newtonsoft.Json; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Keras.Common; | |||||
| using Tensorflow.NumPy; | using Tensorflow.NumPy; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| [JsonConverter(typeof(CustomizedShapeJsonConverter))] | |||||
| public class Shape | public class Shape | ||||
| { | { | ||||
| public int ndim => _dims == null ? -1 : _dims.Length; | public int ndim => _dims == null ? -1 : _dims.Length; | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class Constant<T> : IInitializer | public class Constant<T> : IInitializer | ||||
| @@ -22,11 +24,19 @@ namespace Tensorflow.Operations.Initializers | |||||
| T value; | T value; | ||||
| bool _verify_shape; | bool _verify_shape; | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public string ClassName => "Constant"; | |||||
| public IDictionary<string, object> Config => _config; | |||||
| public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) | ||||
| { | { | ||||
| this.value = value; | this.value = value; | ||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| _verify_shape = verify_shape; | _verify_shape = verify_shape; | ||||
| _config = new Dictionary<string, object>(); | |||||
| _config["value"] = this.value; | |||||
| } | } | ||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| @@ -14,10 +14,17 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class GlorotUniform : VarianceScaling | public class GlorotUniform : VarianceScaling | ||||
| { | { | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public override string ClassName => "GlorotUniform"; | |||||
| public override IDictionary<string, object> Config => _config; | |||||
| public GlorotUniform(float scale = 1.0f, | public GlorotUniform(float scale = 1.0f, | ||||
| string mode = "FAN_AVG", | string mode = "FAN_AVG", | ||||
| bool uniform = true, | bool uniform = true, | ||||
| @@ -28,7 +35,8 @@ namespace Tensorflow.Operations.Initializers | |||||
| seed: seed, | seed: seed, | ||||
| dtype: dtype) | dtype: dtype) | ||||
| { | { | ||||
| _config = new Dictionary<string, object>(); | |||||
| _config["seed"] = _seed; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -14,10 +14,17 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Newtonsoft.Json; | |||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public interface IInitializer | public interface IInitializer | ||||
| { | { | ||||
| [JsonProperty("class_name")] | |||||
| string ClassName { get; } | |||||
| [JsonProperty("config")] | |||||
| IDictionary<string, object> Config { get; } | |||||
| Tensor Apply(InitializerArgs args); | Tensor Apply(InitializerArgs args); | ||||
| } | } | ||||
| } | } | ||||
| @@ -14,12 +14,19 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class Ones : IInitializer | public class Ones : IInitializer | ||||
| { | { | ||||
| private TF_DataType dtype; | private TF_DataType dtype; | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public string ClassName => "Ones"; | |||||
| public IDictionary<string, object> Config => new Dictionary<string, object>(); | |||||
| public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) | public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
| { | { | ||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| @@ -1,9 +1,14 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class Orthogonal : IInitializer | public class Orthogonal : IInitializer | ||||
| { | { | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public string ClassName => "Orthogonal"; | |||||
| public IDictionary<string, object> Config => throw new NotImplementedException(); | |||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class RandomNormal : IInitializer | public class RandomNormal : IInitializer | ||||
| @@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers | |||||
| private int? seed; | private int? seed; | ||||
| private TF_DataType dtype; | private TF_DataType dtype; | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public string ClassName => "RandomNormal"; | |||||
| public IDictionary<string, object> Config => _config; | |||||
| public RandomNormal(float mean = 0.0f, | public RandomNormal(float mean = 0.0f, | ||||
| float stddev = 0.05f, | float stddev = 0.05f, | ||||
| int? seed = null, | int? seed = null, | ||||
| @@ -32,6 +39,11 @@ namespace Tensorflow.Operations.Initializers | |||||
| this.stddev = stddev; | this.stddev = stddev; | ||||
| this.seed = seed; | this.seed = seed; | ||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| _config = new Dictionary<string, object>(); | |||||
| _config["mean"] = this.mean; | |||||
| _config["stddev"] = this.stddev; | |||||
| _config["seed"] = this.seed; | |||||
| } | } | ||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class RandomUniform : IInitializer | public class RandomUniform : IInitializer | ||||
| @@ -23,12 +25,22 @@ namespace Tensorflow.Operations.Initializers | |||||
| private float maxval; | private float maxval; | ||||
| private TF_DataType dtype; | private TF_DataType dtype; | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public string ClassName => "RandomUniform"; | |||||
| public IDictionary<string, object> Config => _config; | |||||
| public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null) | public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null) | ||||
| { | { | ||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| this.minval = minval; | this.minval = minval; | ||||
| this.maxval = maxval; | this.maxval = maxval; | ||||
| this.seed = seed; | this.seed = seed; | ||||
| _config = new Dictionary<string, object>(); | |||||
| _config["minval"] = this.minval; | |||||
| _config["maxval"] = this.maxval; | |||||
| _config["seed"] = this.seed; | |||||
| } | } | ||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class TruncatedNormal : IInitializer | public class TruncatedNormal : IInitializer | ||||
| @@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers | |||||
| private int? seed; | private int? seed; | ||||
| private TF_DataType dtype; | private TF_DataType dtype; | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public string ClassName => "TruncatedNormal"; | |||||
| public IDictionary<string, object> Config => _config; | |||||
| public TruncatedNormal(float mean = 0.0f, | public TruncatedNormal(float mean = 0.0f, | ||||
| float stddev = 1.0f, | float stddev = 1.0f, | ||||
| int? seed = null, | int? seed = null, | ||||
| @@ -32,6 +39,10 @@ namespace Tensorflow.Operations.Initializers | |||||
| this.stddev = stddev; | this.stddev = stddev; | ||||
| this.seed = seed; | this.seed = seed; | ||||
| this.dtype = dtype; | this.dtype = dtype; | ||||
| _config = new Dictionary<string, object>(); | |||||
| _config["mean"] = this.mean; | |||||
| _config["stddev"] = this.stddev; | |||||
| _config["seed"] = this.seed; | |||||
| } | } | ||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| @@ -15,7 +15,9 @@ | |||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | |||||
| using System.Linq; | using System.Linq; | ||||
| using System.Linq.Expressions; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| @@ -30,6 +32,11 @@ namespace Tensorflow.Operations.Initializers | |||||
| protected int? _seed; | protected int? _seed; | ||||
| protected TF_DataType _dtype; | protected TF_DataType _dtype; | ||||
| protected bool _uniform; | protected bool _uniform; | ||||
| private readonly Dictionary<string, object> _config; | |||||
| public virtual string ClassName => "VarianceScaling"; | |||||
| public virtual IDictionary<string, object> Config => _config; | |||||
| public VarianceScaling(float factor = 2.0f, | public VarianceScaling(float factor = 2.0f, | ||||
| string mode = "FAN_IN", | string mode = "FAN_IN", | ||||
| @@ -50,6 +57,12 @@ namespace Tensorflow.Operations.Initializers | |||||
| _seed = seed; | _seed = seed; | ||||
| _dtype = dtype; | _dtype = dtype; | ||||
| _uniform = uniform; | _uniform = uniform; | ||||
| _config = new(); | |||||
| _config["scale"] = _scale; | |||||
| _config["mode"] = _mode; | |||||
| _config["distribution"] = _distribution; | |||||
| _config["seed"] = _seed; | |||||
| } | } | ||||
| public Tensor Apply(InitializerArgs args) | public Tensor Apply(InitializerArgs args) | ||||
| @@ -14,6 +14,8 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using System.Collections.Generic; | |||||
| namespace Tensorflow.Operations.Initializers | namespace Tensorflow.Operations.Initializers | ||||
| { | { | ||||
| public class Zeros : IInitializer | public class Zeros : IInitializer | ||||
| @@ -21,6 +23,9 @@ namespace Tensorflow.Operations.Initializers | |||||
| Shape shape; | Shape shape; | ||||
| TF_DataType dtype; | TF_DataType dtype; | ||||
| public string ClassName => "Zeros"; | |||||
| public IDictionary<string, object> Config => new Dictionary<string, object>(); | |||||
| public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) | public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) | ||||
| { | { | ||||
| this.shape = shape; | this.shape = shape; | ||||
| @@ -20,6 +20,7 @@ using Tensorflow.Keras; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.Operations; | using Tensorflow.Operations; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| @@ -76,6 +77,8 @@ namespace Tensorflow | |||||
| public Shape BatchInputShape => throw new NotImplementedException(); | public Shape BatchInputShape => throw new NotImplementedException(); | ||||
| public TensorShapeConfig BuildInputShape => throw new NotImplementedException(); | |||||
| public TF_DataType DType => throw new NotImplementedException(); | public TF_DataType DType => throw new NotImplementedException(); | ||||
| protected bool built = false; | protected bool built = false; | ||||
| public bool Built => built; | public bool Built => built; | ||||
| @@ -144,7 +147,7 @@ namespace Tensorflow | |||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| public LayerArgs get_config() | |||||
| public IKerasConfig get_config() | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| } | } | ||||
| @@ -108,6 +108,7 @@ https://tensorflownet.readthedocs.io</Description> | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | ||||
| <PackageReference Include="Newtonsoft.Json" Version="13.0.2" /> | |||||
| <PackageReference Include="Protobuf.Text" Version="0.5.0" /> | <PackageReference Include="Protobuf.Text" Version="0.5.0" /> | ||||
| <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -202,6 +202,24 @@ namespace Tensorflow | |||||
| _ => type.ToString() | _ => type.ToString() | ||||
| }; | }; | ||||
| public static string as_python_name(this TF_DataType type) | |||||
| => type switch | |||||
| { | |||||
| TF_DataType.TF_STRING => "str", | |||||
| TF_DataType.TF_UINT8 => "uint8", | |||||
| TF_DataType.TF_INT8 => "int8", | |||||
| TF_DataType.TF_UINT32 => "uint32", | |||||
| TF_DataType.TF_INT32 => "int32", | |||||
| TF_DataType.TF_UINT64 => "uint64", | |||||
| TF_DataType.TF_INT64 => "int64", | |||||
| TF_DataType.TF_FLOAT => "float32", | |||||
| TF_DataType.TF_DOUBLE => "float64", | |||||
| TF_DataType.TF_BOOL => "bool", | |||||
| TF_DataType.TF_RESOURCE => "resource", | |||||
| TF_DataType.TF_VARIANT => "variant", | |||||
| _ => type.ToString() | |||||
| }; | |||||
| public static int get_datatype_size(this TF_DataType type) | public static int get_datatype_size(this TF_DataType type) | ||||
| => type.as_base_dtype() switch | => type.as_base_dtype() switch | ||||
| { | { | ||||
| @@ -5,7 +5,7 @@ using Tensorflow.Train; | |||||
| namespace Tensorflow.Training | namespace Tensorflow.Training | ||||
| { | { | ||||
| public interface ITrackable | |||||
| public interface IWithTrackable | |||||
| { | { | ||||
| Trackable GetTrackable(); | Trackable GetTrackable(); | ||||
| } | } | ||||
| @@ -26,7 +26,7 @@ using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Train | namespace Tensorflow.Train | ||||
| { | { | ||||
| public abstract class Trackable: ITrackable | |||||
| public abstract class Trackable: IWithTrackable | |||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Corresponding to tensorflow/python/trackable/constants.py | /// Corresponding to tensorflow/python/trackable/constants.py | ||||
| @@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| public partial class Functional | public partial class Functional | ||||
| { | { | ||||
| public ModelConfig get_config() | |||||
| public override IKerasConfig get_config() | |||||
| { | { | ||||
| return get_network_config(); | return get_network_config(); | ||||
| } | } | ||||
| @@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine | |||||
| { | { | ||||
| Name = name | Name = name | ||||
| }; | }; | ||||
| var node_conversion_map = new Dictionary<string, int>(); | var node_conversion_map = new Dictionary<string, int>(); | ||||
| foreach (var layer in _self_tracked_trackables) | foreach (var layer in _self_tracked_trackables) | ||||
| { | { | ||||
| @@ -42,23 +42,26 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| var layer_configs = new List<LayerConfig>(); | var layer_configs = new List<LayerConfig>(); | ||||
| foreach (var layer in _self_tracked_trackables) | |||||
| using (SharedObjectSavingScope.Enter()) | |||||
| { | { | ||||
| var filtered_inbound_nodes = new List<NodeConfig>(); | |||||
| foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | |||||
| foreach (var layer in _self_tracked_trackables) | |||||
| { | { | ||||
| var node_key = _make_node_key(layer.Name, original_node_index); | |||||
| if (NetworkNodes.Contains(node_key) && !node.is_input) | |||||
| var filtered_inbound_nodes = new List<NodeConfig>(); | |||||
| foreach (var (original_node_index, node) in enumerate(layer.InboundNodes)) | |||||
| { | { | ||||
| var node_data = node.serialize(_make_node_key, node_conversion_map); | |||||
| filtered_inbound_nodes.append(node_data); | |||||
| var node_key = _make_node_key(layer.Name, original_node_index); | |||||
| if (NetworkNodes.Contains(node_key) && !node.is_input) | |||||
| { | |||||
| var node_data = node.serialize(_make_node_key, node_conversion_map); | |||||
| filtered_inbound_nodes.append(node_data); | |||||
| } | |||||
| } | } | ||||
| } | |||||
| var layer_config = generic_utils.serialize_layer_to_config(layer); | |||||
| layer_config.Name = layer.Name; | |||||
| layer_config.InboundNodes = filtered_inbound_nodes; | |||||
| layer_configs.Add(layer_config); | |||||
| var layer_config = generic_utils.serialize_layer_to_config(layer); | |||||
| layer_config.Name = layer.Name; | |||||
| layer_config.InboundNodes = filtered_inbound_nodes; | |||||
| layer_configs.Add(layer_config); | |||||
| } | |||||
| } | } | ||||
| config.Layers = layer_configs; | config.Layers = layer_configs; | ||||
| @@ -70,6 +70,7 @@ namespace Tensorflow.Keras.Engine | |||||
| this.inputs = inputs; | this.inputs = inputs; | ||||
| this.outputs = outputs; | this.outputs = outputs; | ||||
| built = true; | built = true; | ||||
| _buildInputShape = inputs.shape; | |||||
| if (outputs.Any(x => x.KerasHistory == null)) | if (outputs.Any(x => x.KerasHistory == null)) | ||||
| base_layer_utils.create_keras_history(outputs); | base_layer_utils.create_keras_history(outputs); | ||||
| @@ -357,5 +358,22 @@ namespace Tensorflow.Keras.Engine | |||||
| return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) | return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) | ||||
| .ToDictionary(x => x.Key, x => x.Value); | .ToDictionary(x => x.Key, x => x.Value); | ||||
| } | } | ||||
| protected override void _init_set_name(string name, bool zero_based = true) | |||||
| { | |||||
| if (string.IsNullOrEmpty(name)) | |||||
| { | |||||
| string class_name = GetType().Name; | |||||
| if (this.GetType() == typeof(Functional)) | |||||
| { | |||||
| class_name = "Model"; | |||||
| } | |||||
| this.name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(class_name), zero_based: zero_based); | |||||
| } | |||||
| else | |||||
| { | |||||
| this.name = name; | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -61,6 +61,7 @@ namespace Tensorflow.Keras.Engine | |||||
| /// Provides information about which inputs are compatible with the layer. | /// Provides information about which inputs are compatible with the layer. | ||||
| /// </summary> | /// </summary> | ||||
| protected InputSpec inputSpec; | protected InputSpec inputSpec; | ||||
| public InputSpec InputSpec => inputSpec; | |||||
| bool dynamic = true; | bool dynamic = true; | ||||
| public bool SupportsMasking { get; set; } | public bool SupportsMasking { get; set; } | ||||
| protected List<IVariableV1> _trainable_weights; | protected List<IVariableV1> _trainable_weights; | ||||
| @@ -79,6 +80,8 @@ namespace Tensorflow.Keras.Engine | |||||
| protected bool computePreviousMask; | protected bool computePreviousMask; | ||||
| protected List<Operation> updates; | protected List<Operation> updates; | ||||
| public Shape BatchInputShape => args.BatchInputShape; | public Shape BatchInputShape => args.BatchInputShape; | ||||
| protected TensorShapeConfig _buildInputShape = null; | |||||
| public TensorShapeConfig BuildInputShape => _buildInputShape; | |||||
| List<INode> inboundNodes; | List<INode> inboundNodes; | ||||
| public List<INode> InboundNodes => inboundNodes; | public List<INode> InboundNodes => inboundNodes; | ||||
| @@ -223,6 +226,7 @@ namespace Tensorflow.Keras.Engine | |||||
| public virtual void build(Shape input_shape) | public virtual void build(Shape input_shape) | ||||
| { | { | ||||
| _buildInputShape = input_shape; | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| @@ -310,7 +314,7 @@ namespace Tensorflow.Keras.Engine | |||||
| public List<IVariableV1> Variables => weights; | public List<IVariableV1> Variables => weights; | ||||
| public virtual LayerArgs get_config() | |||||
| public virtual IKerasConfig get_config() | |||||
| => args; | => args; | ||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Tensorflow.Functions; | using Tensorflow.Functions; | ||||
| using Tensorflow.Keras.Metrics; | using Tensorflow.Keras.Metrics; | ||||
| using Tensorflow.Keras.Saving; | |||||
| using Tensorflow.Keras.Saving.SavedModel; | using Tensorflow.Keras.Saving.SavedModel; | ||||
| using Tensorflow.ModelSaving; | using Tensorflow.ModelSaving; | ||||
| @@ -30,7 +31,10 @@ namespace Tensorflow.Keras.Engine | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||||
| using (SharedObjectSavingScope.Enter()) | |||||
| { | |||||
| KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -25,6 +25,7 @@ namespace Tensorflow.Keras.Layers { | |||||
| { | { | ||||
| throw new ValueError("Alpha must be a number greater than 0."); | throw new ValueError("Alpha must be a number greater than 0."); | ||||
| } | } | ||||
| _buildInputShape = input_shape; | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| @@ -14,6 +14,7 @@ namespace Tensorflow.Keras.Layers { | |||||
| } | } | ||||
| public override void build(Shape input_shape) | public override void build(Shape input_shape) | ||||
| { | { | ||||
| _buildInputShape = input_shape; | |||||
| built = true; | built = true; | ||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
| @@ -16,10 +16,11 @@ namespace Tensorflow.Keras.Layers { | |||||
| // SELU has no arguments | // SELU has no arguments | ||||
| } | } | ||||
| public override void build(Shape input_shape) { | public override void build(Shape input_shape) { | ||||
| if ( alpha < 0f ) { | |||||
| throw new ValueError("Alpha must be a number greater than 0."); | |||||
| } | |||||
| built = true; | |||||
| if ( alpha < 0f ) { | |||||
| throw new ValueError("Alpha must be a number greater than 0."); | |||||
| } | |||||
| _buildInputShape = input_shape; | |||||
| built = true; | |||||
| } | } | ||||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | ||||
| Tensor output = inputs; | Tensor output = inputs; | ||||
| @@ -4,6 +4,7 @@ using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Layers | namespace Tensorflow.Keras.Layers | ||||
| { | { | ||||
| @@ -146,7 +147,7 @@ namespace Tensorflow.Keras.Layers | |||||
| return scores; | return scores; | ||||
| } | } | ||||
| public override LayerArgs get_config() => this.args; | |||||
| public override IKerasConfig get_config() => this.args; | |||||
| //var config = new Dictionary<object, object> { | //var config = new Dictionary<object, object> { | ||||
| // { | // { | ||||
| // "use_scale", | // "use_scale", | ||||
| @@ -5,6 +5,7 @@ using static Tensorflow.KerasApi; | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.Saving; | |||||
| /// <summary> | /// <summary> | ||||
| /// Base class for attention layers that can be used in sequence DNN/CNN models. | /// Base class for attention layers that can be used in sequence DNN/CNN models. | ||||
| @@ -252,6 +253,6 @@ namespace Tensorflow.Keras.Layers | |||||
| return tf.logical_and(x, y); | return tf.logical_and(x, y); | ||||
| } | } | ||||
| public override LayerArgs get_config() => this.args; | |||||
| public override IKerasConfig get_config() => this.args; | |||||
| } | } | ||||
| } | } | ||||
| @@ -49,6 +49,7 @@ namespace Tensorflow.Keras.Layers | |||||
| initializer: bias_initializer, | initializer: bias_initializer, | ||||
| trainable: true); | trainable: true); | ||||
| built = true; | built = true; | ||||
| _buildInputShape = input_shape; | |||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
| @@ -98,6 +98,7 @@ namespace Tensorflow.Keras.Layers | |||||
| name: tf_op_name); | name: tf_op_name); | ||||
| built = true; | built = true; | ||||
| _buildInputShape = input_shape; | |||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false) | ||||
| @@ -43,6 +43,7 @@ namespace Tensorflow.Keras.Layers | |||||
| public override void build(Shape input_shape) | public override void build(Shape input_shape) | ||||
| { | { | ||||
| _buildInputShape = input_shape; | |||||
| var last_dim = input_shape.dims.Last(); | var last_dim = input_shape.dims.Last(); | ||||
| var axes = new Dictionary<int, int>(); | var axes = new Dictionary<int, int>(); | ||||
| axes[-1] = (int)last_dim; | axes[-1] = (int)last_dim; | ||||
| @@ -62,6 +62,7 @@ namespace Tensorflow.Keras.Layers | |||||
| name: "embeddings"); | name: "embeddings"); | ||||
| tf.Context.graph_mode(); | tf.Context.graph_mode(); | ||||
| built = true; | built = true; | ||||
| _buildInputShape = input_shape; | |||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
| @@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Layers { | |||||
| throw new ValueError("The `cropping` argument must be a tuple of 2 integers."); | throw new ValueError("The `cropping` argument must be a tuple of 2 integers."); | ||||
| } | } | ||||
| built = true; | built = true; | ||||
| _buildInputShape = input_shape; | |||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
| @@ -13,7 +13,8 @@ namespace Tensorflow.Keras.Layers { | |||||
| this.args = args; | this.args = args; | ||||
| } | } | ||||
| public override void build(Shape input_shape) { | public override void build(Shape input_shape) { | ||||
| built = true; | |||||
| built = true; | |||||
| _buildInputShape = input_shape; | |||||
| } | } | ||||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | ||||
| Tensor output = inputs; | Tensor output = inputs; | ||||
| @@ -12,7 +12,8 @@ namespace Tensorflow.Keras.Layers { | |||||
| } | } | ||||
| public override void build(Shape input_shape) { | public override void build(Shape input_shape) { | ||||
| built = true; | |||||
| built = true; | |||||
| _buildInputShape = input_shape; | |||||
| } | } | ||||
| protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { | ||||
| @@ -300,7 +300,8 @@ namespace Tensorflow.Keras.Layers | |||||
| => new Dense(new DenseArgs | => new Dense(new DenseArgs | ||||
| { | { | ||||
| Units = units, | Units = units, | ||||
| Activation = GetActivationByName("linear") | |||||
| Activation = GetActivationByName("linear"), | |||||
| ActivationName = "linear" | |||||
| }); | }); | ||||
| /// <summary> | /// <summary> | ||||
| @@ -321,6 +322,7 @@ namespace Tensorflow.Keras.Layers | |||||
| { | { | ||||
| Units = units, | Units = units, | ||||
| Activation = GetActivationByName(activation), | Activation = GetActivationByName(activation), | ||||
| ActivationName = activation, | |||||
| InputShape = input_shape | InputShape = input_shape | ||||
| }); | }); | ||||
| @@ -37,6 +37,7 @@ namespace Tensorflow.Keras.Layers | |||||
| }).ToArray(); | }).ToArray(); | ||||
| shape_set.Add(shape); | shape_set.Add(shape); | ||||
| }*/ | }*/ | ||||
| _buildInputShape = input_shape; | |||||
| } | } | ||||
| protected override Tensors _merge_function(Tensors inputs) | protected override Tensors _merge_function(Tensors inputs) | ||||
| @@ -17,6 +17,7 @@ namespace Tensorflow.Keras.Layers | |||||
| public override void build(Shape input_shape) | public override void build(Shape input_shape) | ||||
| { | { | ||||
| // output_shape = input_shape.dims[1^]; | // output_shape = input_shape.dims[1^]; | ||||
| _buildInputShape = input_shape; | |||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
| @@ -118,6 +118,7 @@ namespace Tensorflow.Keras.Layers | |||||
| throw new NotImplementedException("build when renorm is true"); | throw new NotImplementedException("build when renorm is true"); | ||||
| built = true; | built = true; | ||||
| _buildInputShape = input_shape; | |||||
| } | } | ||||
| public override Shape ComputeOutputShape(Shape input_shape) | public override Shape ComputeOutputShape(Shape input_shape) | ||||
| @@ -81,6 +81,7 @@ namespace Tensorflow.Keras.Layers | |||||
| _fused = _fused_can_be_used(ndims); | _fused = _fused_can_be_used(ndims); | ||||
| built = true; | built = true; | ||||
| _buildInputShape = input_shape; | |||||
| } | } | ||||
| bool _fused_can_be_used(int ndims) | bool _fused_can_be_used(int ndims) | ||||
| @@ -24,6 +24,7 @@ namespace Tensorflow.Keras.Layers { | |||||
| permute = new int[input_shape.rank]; | permute = new int[input_shape.rank]; | ||||
| dims.CopyTo(permute, 1); | dims.CopyTo(permute, 1); | ||||
| built = true; | built = true; | ||||
| _buildInputShape = input_shape; | |||||
| } | } | ||||
| protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) | ||||
| { | { | ||||
| @@ -18,6 +18,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| public override void build(Shape input_shape) | public override void build(Shape input_shape) | ||||
| { | { | ||||
| var input_dim = input_shape[-1]; | var input_dim = input_shape[-1]; | ||||
| _buildInputShape = input_shape; | |||||
| kernel = add_weight("kernel", (input_shape[-1], args.Units), | kernel = add_weight("kernel", (input_shape[-1], args.Units), | ||||
| initializer: args.KernelInitializer | initializer: args.KernelInitializer | ||||
| @@ -4,6 +4,7 @@ using System.ComponentModel; | |||||
| using Tensorflow.Keras.ArgsDefinition; | using Tensorflow.Keras.ArgsDefinition; | ||||
| using Tensorflow.Keras.ArgsDefinition.Rnn; | using Tensorflow.Keras.ArgsDefinition.Rnn; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace Tensorflow.Keras.Layers.Rnn | namespace Tensorflow.Keras.Layers.Rnn | ||||
| { | { | ||||
| @@ -136,7 +137,7 @@ namespace Tensorflow.Keras.Layers.Rnn | |||||
| // self.built = True | // self.built = True | ||||
| } | } | ||||
| public override LayerArgs get_config() | |||||
| public override IKerasConfig get_config() | |||||
| { | { | ||||
| throw new NotImplementedException(); | throw new NotImplementedException(); | ||||
| //def get_config(self): | //def get_config(self): | ||||
| @@ -79,7 +79,7 @@ public partial class KerasSavedModelUtils | |||||
| var path = node_paths[node]; | var path = node_paths[node]; | ||||
| string node_path; | string node_path; | ||||
| if (path is null) | |||||
| if (path is null || path.Count() == 0) | |||||
| { | { | ||||
| node_path = "root"; | node_path = "root"; | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using Newtonsoft.Json; | using Newtonsoft.Json; | ||||
| using Newtonsoft.Json.Linq; | using Newtonsoft.Json.Linq; | ||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using Tensorflow.Keras.Layers; | using Tensorflow.Keras.Layers; | ||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| @@ -85,31 +86,38 @@ public class LayerSavedModelSaver: SavedModelSaver | |||||
| JObject metadata = new JObject(); | JObject metadata = new JObject(); | ||||
| metadata["name"] = _layer.Name; | metadata["name"] = _layer.Name; | ||||
| metadata["trainable"] = _layer.Trainable; | metadata["trainable"] = _layer.Trainable; | ||||
| // metadata["expects_training_arg"] = _obj._expects_training_arg; | |||||
| // metadata["dtype"] = policy.serialize(_obj._dtype_policy) | |||||
| // TODO: implement `expects_training_arg`. | |||||
| metadata["expects_training_arg"] = false; | |||||
| metadata["dtype"] = _layer.DType.as_python_name(); | |||||
| metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape); | metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape); | ||||
| // metadata["stateful"] = _obj.stateful; | // metadata["stateful"] = _obj.stateful; | ||||
| // metadata["must_restore_from_config"] = _obj.must_restore_from_config; | // metadata["must_restore_from_config"] = _obj.must_restore_from_config; | ||||
| // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; | // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; | ||||
| metadata["autocast"] = _layer.AutoCast; | metadata["autocast"] = _layer.AutoCast; | ||||
| var temp = JObject.FromObject(get_serialized(_layer)); | |||||
| metadata.Merge(temp, new JsonMergeSettings | |||||
| if(_layer.InputSpec is not null) | |||||
| { | |||||
| metadata["input_spec"] = generic_utils.serialize_keras_object(_layer.InputSpec); | |||||
| } | |||||
| metadata.Merge(get_serialized(_layer), new JsonMergeSettings | |||||
| { | { | ||||
| // Handle conflicts by using values from obj2 | // Handle conflicts by using values from obj2 | ||||
| MergeArrayHandling = MergeArrayHandling.Merge | MergeArrayHandling = MergeArrayHandling.Merge | ||||
| }); | }); | ||||
| // skip the check of `input_spec` and `build_input_shape` for the lack of members. | // skip the check of `input_spec` and `build_input_shape` for the lack of members. | ||||
| // skip the check of `activity_regularizer` for the type problem. | // skip the check of `activity_regularizer` for the type problem. | ||||
| if(_layer.BuildInputShape is not null) | |||||
| { | |||||
| metadata["build_input_shape"] = JToken.FromObject(_layer.BuildInputShape); | |||||
| } | |||||
| return metadata.ToString(); | return metadata.ToString(); | ||||
| } | } | ||||
| } | } | ||||
| public static IDictionary<string, object> get_serialized(Layer obj) | |||||
| public static JObject get_serialized(Layer obj) | |||||
| { | { | ||||
| // TODO: complete the implmentation (need to revise `get_config`). | |||||
| return new Dictionary<string, object>(); | |||||
| //return generic_utils.serialize_keras_object(obj); | |||||
| return generic_utils.serialize_keras_object(obj); | |||||
| } | } | ||||
| } | } | ||||
| @@ -135,18 +143,19 @@ public class InputLayerSavedModelSaver: SavedModelSaver | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| if(_obj is not Layer) | |||||
| if(_obj is not InputLayer) | |||||
| { | { | ||||
| throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer."); | throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer."); | ||||
| } | } | ||||
| var layer = (Layer)_obj; | |||||
| var layer = (InputLayer)_obj; | |||||
| var config = (layer.get_config() as InputLayerArgs)!; | |||||
| var info = new | var info = new | ||||
| { | { | ||||
| class_name = layer.GetType().Name, | class_name = layer.GetType().Name, | ||||
| name = layer.Name, | name = layer.Name, | ||||
| dtype = layer.DType, | dtype = layer.DType, | ||||
| //sparse = layer.sparse, | |||||
| //ragged = layer.ragged, | |||||
| sparse = config.Sparse, | |||||
| ragged = config.Ragged, | |||||
| batch_input_shape = layer.BatchInputShape, | batch_input_shape = layer.BatchInputShape, | ||||
| config = layer.get_config() | config = layer.get_config() | ||||
| }; | }; | ||||
| @@ -1,15 +0,0 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| namespace Tensorflow.Keras.Saving | |||||
| { | |||||
| public class TensorShapeConfig | |||||
| { | |||||
| public string ClassName { get; set; } | |||||
| public int?[] Items { get; set; } | |||||
| public static implicit operator Shape(TensorShapeConfig shape) | |||||
| => shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray()); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,125 @@ | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Reflection; | |||||
| using System.Text; | |||||
| using Tensorflow.Keras.Saving.SavedModel; | |||||
| namespace Tensorflow.Keras.Saving | |||||
| { | |||||
| // TODO: make it thread safe. | |||||
| public class SharedObjectSavingScope: IDisposable | |||||
| { | |||||
| private class WeakReferenceEqualityComparer: IEqualityComparer<WeakReference<object>> | |||||
| { | |||||
| public bool Equals(WeakReference<object> x, WeakReference<object> y) | |||||
| { | |||||
| if(!x.TryGetTarget(out var tx)) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| if(!y.TryGetTarget(out var ty)) | |||||
| { | |||||
| return false; | |||||
| } | |||||
| return tx.Equals(ty); | |||||
| } | |||||
| public int GetHashCode(WeakReference<object> obj) | |||||
| { | |||||
| if (!obj.TryGetTarget(out var w)) | |||||
| { | |||||
| return 0; | |||||
| } | |||||
| return w.GetHashCode(); | |||||
| } | |||||
| } | |||||
| private static SharedObjectSavingScope? _instance = null; | |||||
| private readonly Dictionary<WeakReference<object>, int> _shared_object_ids= new Dictionary<WeakReference<object>, int>(); | |||||
| private int _currentId = 0; | |||||
| /// <summary> | |||||
| /// record how many times the scope is nested. | |||||
| /// </summary> | |||||
| private int _nestedDepth = 0; | |||||
| private SharedObjectSavingScope() | |||||
| { | |||||
| } | |||||
| public static SharedObjectSavingScope Enter() | |||||
| { | |||||
| if(_instance is not null) | |||||
| { | |||||
| _instance._nestedDepth++; | |||||
| return _instance; | |||||
| } | |||||
| else | |||||
| { | |||||
| _instance = new SharedObjectSavingScope(); | |||||
| _instance._nestedDepth++; | |||||
| return _instance; | |||||
| } | |||||
| } | |||||
| public static SharedObjectSavingScope GetScope() | |||||
| { | |||||
| return _instance; | |||||
| } | |||||
| public int GetId(object? obj) | |||||
| { | |||||
| if(obj is null) | |||||
| { | |||||
| return _currentId++; | |||||
| } | |||||
| var maybe_key = _shared_object_ids.Keys.SingleOrDefault(x => new WeakReferenceEqualityComparer().Equals(x, new WeakReference<object>(obj))); | |||||
| if (maybe_key is not null) | |||||
| { | |||||
| return _shared_object_ids[maybe_key]; | |||||
| } | |||||
| _shared_object_ids[new WeakReference<object>(obj)] = _currentId++; | |||||
| return _currentId; | |||||
| } | |||||
| public void Dispose() | |||||
| { | |||||
| _nestedDepth--; | |||||
| if(_nestedDepth== 0) | |||||
| { | |||||
| _instance = null; | |||||
| } | |||||
| } | |||||
| } | |||||
| public static class serialize_utils | |||||
| { | |||||
| public static readonly string SHARED_OBJECT_KEY = "shared_object_id"; | |||||
| /// <summary> | |||||
| /// Returns the serialization of the class with the given config. | |||||
| /// </summary> | |||||
| /// <param name="class_name"></param> | |||||
| /// <param name="config"></param> | |||||
| /// <param name="obj"></param> | |||||
| /// <param name="shared_object_id"></param> | |||||
| /// <returns></returns> | |||||
| public static JObject serialize_keras_class_and_config(string class_name, JToken config, object? obj = null, int? shared_object_id = null) | |||||
| { | |||||
| JObject res = new JObject(); | |||||
| res["class_name"] = class_name; | |||||
| res["config"] = config; | |||||
| if(shared_object_id is not null) | |||||
| { | |||||
| res[SHARED_OBJECT_KEY] = shared_object_id!; | |||||
| } | |||||
| var scope = SharedObjectSavingScope.GetScope(); | |||||
| if(scope is not null && obj is not null) | |||||
| { | |||||
| res[SHARED_OBJECT_KEY] = scope.GetId(obj); | |||||
| } | |||||
| return res; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Utils | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| /// Makes a layer name (or arbitrary string) unique within a TensorFlow graph. | |||||
| /// Makes a layer name (or arbitrary string) unique within a TensorFlow graph. (correponding to `backend.unique_object_name` of python.) | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| @@ -14,10 +14,14 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Newtonsoft.Json; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using System; | using System; | ||||
| using System.Collections; | using System.Collections; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Keras.ArgsDefinition; | |||||
| using Tensorflow.Keras.Saving; | using Tensorflow.Keras.Saving; | ||||
| namespace Tensorflow.Keras.Utils | namespace Tensorflow.Keras.Utils | ||||
| @@ -32,13 +36,21 @@ namespace Tensorflow.Keras.Utils | |||||
| public static LayerConfig serialize_layer_to_config(ILayer instance) | public static LayerConfig serialize_layer_to_config(ILayer instance) | ||||
| { | { | ||||
| var config = instance.get_config(); | var config = instance.get_config(); | ||||
| Debug.Assert(config is LayerArgs); | |||||
| return new LayerConfig | return new LayerConfig | ||||
| { | { | ||||
| Config = config, | |||||
| Config = config as LayerArgs, | |||||
| ClassName = instance.GetType().Name | ClassName = instance.GetType().Name | ||||
| }; | }; | ||||
| } | } | ||||
| public static JObject serialize_keras_object(IKerasConfigable instance) | |||||
| { | |||||
| var config = JToken.FromObject(instance.get_config()); | |||||
| // TODO: change the class_name to registered name, instead of system class name. | |||||
| return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance); | |||||
| } | |||||
| public static string to_snake_case(string name) | public static string to_snake_case(string name) | ||||
| { | { | ||||
| return string.Concat(name.Select((x, i) => | return string.Concat(name.Select((x, i) => | ||||
| @@ -1,6 +1,8 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using Tensorflow.Keras.Engine; | using Tensorflow.Keras.Engine; | ||||
| using System.Diagnostics; | |||||
| using static Tensorflow.KerasApi; | using static Tensorflow.KerasApi; | ||||
| using Tensorflow.Keras.Saving; | |||||
| namespace TensorFlowNET.Keras.UnitTest | namespace TensorFlowNET.Keras.UnitTest | ||||
| { | { | ||||
| @@ -15,7 +17,8 @@ namespace TensorFlowNET.Keras.UnitTest | |||||
| { | { | ||||
| var model = GetFunctionalModel(); | var model = GetFunctionalModel(); | ||||
| var config = model.get_config(); | var config = model.get_config(); | ||||
| var new_model = keras.models.from_config(config); | |||||
| Debug.Assert(config is ModelConfig); | |||||
| var new_model = keras.models.from_config(config as ModelConfig); | |||||
| Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); | Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); | ||||
| } | } | ||||
| @@ -15,17 +15,14 @@ using Tensorflow.Keras.Layers; | |||||
| using Tensorflow.Keras.Losses; | using Tensorflow.Keras.Losses; | ||||
| using Tensorflow.Keras.Metrics; | using Tensorflow.Keras.Metrics; | ||||
| using Tensorflow.Keras.Optimizers; | using Tensorflow.Keras.Optimizers; | ||||
| using Tensorflow.Operations; | |||||
| namespace TensorFlowNET.Keras.UnitTest; | namespace TensorFlowNET.Keras.UnitTest; | ||||
| // class MNISTLoader | |||||
| // { | |||||
| // public MNISTLoader() | |||||
| // { | |||||
| // var mnist = new MnistModelLoader() | |||||
| // | |||||
| // } | |||||
| // } | |||||
| public static class AutoGraphExtension | |||||
| { | |||||
| } | |||||
| [TestClass] | [TestClass] | ||||
| public class SaveTest | public class SaveTest | ||||
| @@ -42,6 +39,8 @@ public class SaveTest | |||||
| model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"}); | model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"}); | ||||
| var g = ops.get_default_graph(); | |||||
| var data_loader = new MnistModelLoader(); | var data_loader = new MnistModelLoader(); | ||||
| var num_epochs = 1; | var num_epochs = 1; | ||||
| var batch_size = 50; | var batch_size = 50; | ||||
| @@ -50,11 +49,34 @@ public class SaveTest | |||||
| { | { | ||||
| TrainDir = "mnist", | TrainDir = "mnist", | ||||
| OneHot = false, | OneHot = false, | ||||
| ValidationSize = 0, | |||||
| ValidationSize = 50000, | |||||
| }).Result; | }).Result; | ||||
| model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); | ||||
| model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); | model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); | ||||
| } | } | ||||
| [TestMethod] | |||||
| public void Temp() | |||||
| { | |||||
| var graph = new Graph(); | |||||
| var g = graph.as_default(); | |||||
| //var input_tensor = array_ops.placeholder(TF_DataType.TF_FLOAT, new int[] { 1 }, "test_string_tensor"); | |||||
| var input_tensor = tf.placeholder(tf.int32, new int[] { 1 }, "aa"); | |||||
| var wrapped_func = tf.autograph.to_graph(func); | |||||
| var res = wrapped_func(input_tensor); | |||||
| g.Exit(); | |||||
| } | |||||
| private Tensor func(Tensor tensor) | |||||
| { | |||||
| return gen_ops.neg(tensor); | |||||
| //return array_ops.identity(tensor); | |||||
| //tf.device("cpu:0"); | |||||
| //using (ops.control_dependencies(new object[] { res.op })) | |||||
| //{ | |||||
| // return array_ops.identity(tensor); | |||||
| //} | |||||
| } | |||||
| } | } | ||||