Browse Source

Add more implementations to the keras part of pb model save.

pull/976/head
AsakusaRinne 2 years ago
parent
commit
2ab0bdbc86
70 changed files with 849 additions and 109 deletions
  1. +13
    -4
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs
  2. +19
    -0
      src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs
  3. +41
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs
  4. +15
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs
  5. +2
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs
  6. +2
    -1
      src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs
  7. +17
    -14
      src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs
  8. +4
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs
  9. +4
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs
  10. +5
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs
  11. +50
    -0
      src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs
  12. +48
    -0
      src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs
  13. +73
    -0
      src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs
  14. +67
    -0
      src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs
  15. +30
    -1
      src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs
  16. +3
    -2
      src/TensorFlowNET.Core/Keras/Layers/ILayer.cs
  17. +15
    -0
      src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs
  18. +7
    -2
      src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs
  19. +7
    -2
      src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs
  20. +5
    -2
      src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs
  21. +21
    -0
      src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs
  22. +10
    -1
      src/TensorFlowNET.Core/NumPy/Axis.cs
  23. +3
    -0
      src/TensorFlowNET.Core/Numpy/Shape.cs
  24. +10
    -0
      src/TensorFlowNET.Core/Operations/Initializers/Constant.cs
  25. +9
    -1
      src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs
  26. +7
    -0
      src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs
  27. +7
    -0
      src/TensorFlowNET.Core/Operations/Initializers/Ones.cs
  28. +5
    -0
      src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs
  29. +12
    -0
      src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs
  30. +12
    -0
      src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs
  31. +11
    -0
      src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs
  32. +13
    -0
      src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs
  33. +5
    -0
      src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs
  34. +4
    -1
      src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs
  35. +1
    -0
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  36. +18
    -0
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  37. +1
    -1
      src/TensorFlowNET.Core/Training/IWithTrackable.cs
  38. +1
    -1
      src/TensorFlowNET.Core/Training/Trackable.cs
  39. +17
    -14
      src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs
  40. +18
    -0
      src/TensorFlowNET.Keras/Engine/Functional.cs
  41. +5
    -1
      src/TensorFlowNET.Keras/Engine/Layer.cs
  42. +5
    -1
      src/TensorFlowNET.Keras/Engine/Model.Save.cs
  43. +1
    -0
      src/TensorFlowNET.Keras/Layers/Activation/ELU.cs
  44. +1
    -0
      src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs
  45. +5
    -4
      src/TensorFlowNET.Keras/Layers/Activation/SELU.cs
  46. +2
    -1
      src/TensorFlowNET.Keras/Layers/Attention/Attention.cs
  47. +2
    -1
      src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs
  48. +1
    -0
      src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
  49. +1
    -0
      src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
  50. +1
    -0
      src/TensorFlowNET.Keras/Layers/Core/Dense.cs
  51. +1
    -0
      src/TensorFlowNET.Keras/Layers/Core/Embedding.cs
  52. +1
    -0
      src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs
  53. +2
    -1
      src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs
  54. +2
    -1
      src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs
  55. +3
    -1
      src/TensorFlowNET.Keras/Layers/LayersApi.cs
  56. +1
    -0
      src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs
  57. +1
    -0
      src/TensorFlowNET.Keras/Layers/Merging/Merge.cs
  58. +1
    -0
      src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
  59. +1
    -0
      src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs
  60. +1
    -0
      src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs
  61. +1
    -0
      src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs
  62. +2
    -1
      src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs
  63. +1
    -1
      src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs
  64. +21
    -12
      src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs
  65. +0
    -15
      src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs
  66. +125
    -0
      src/TensorFlowNET.Keras/Saving/serialization.cs
  67. +1
    -1
      src/TensorFlowNET.Keras/Utils/base_layer_utils.cs
  68. +13
    -1
      src/TensorFlowNET.Keras/Utils/generic_utils.cs
  69. +4
    -1
      test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs
  70. +31
    -9
      test/TensorFlowNET.Keras.UnitTest/SaveTest.cs

+ 13
- 4
src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs View File

@@ -1,9 +1,18 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.ArgsDefinition { namespace Tensorflow.Keras.ArgsDefinition {
public class SoftmaxArgs : LayerArgs {
public Axis axis { get; set; } = -1;
}
public class SoftmaxArgs : LayerArgs
{
[JsonProperty("axis")]
public Axis axis { get; set; } = -1;
[JsonProperty("name")]
public override string Name { get => base.Name; set => base.Name = value; }
[JsonProperty("trainable")]
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
}
} }

+ 19
- 0
src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs View File

@@ -0,0 +1,19 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.ArgsDefinition
{
public class AutoSerializeLayerArgs: LayerArgs
{
[JsonProperty("name")]
public override string Name { get => base.Name; set => base.Name = value; }
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
[JsonProperty("trainable")]
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
}
}

+ 41
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs View File

@@ -1,13 +1,18 @@
using System;
using Newtonsoft.Json;
using System;
using System.Xml.Linq;
using Tensorflow.Operations.Initializers;
using static Tensorflow.Binding; using static Tensorflow.Binding;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
// TODO: `activity_regularizer`
public class DenseArgs : LayerArgs public class DenseArgs : LayerArgs
{ {
/// <summary> /// <summary>
/// Positive integer, dimensionality of the output space. /// Positive integer, dimensionality of the output space.
/// </summary> /// </summary>
[JsonProperty("units")]
public int Units { get; set; } public int Units { get; set; }


/// <summary> /// <summary>
@@ -15,39 +20,74 @@ namespace Tensorflow.Keras.ArgsDefinition
/// </summary> /// </summary>
public Activation Activation { get; set; } public Activation Activation { get; set; }


private string _activationName;
[JsonProperty("activation")]
public string ActivationName
{
get
{
if (string.IsNullOrEmpty(_activationName))
{
return Activation.Method.Name;
}
else
{
return _activationName;
}
}
set
{
_activationName = value;
}
}

/// <summary> /// <summary>
/// Whether the layer uses a bias vector. /// Whether the layer uses a bias vector.
/// </summary> /// </summary>
[JsonProperty("use_bias")]
public bool UseBias { get; set; } = true; public bool UseBias { get; set; } = true;


/// <summary> /// <summary>
/// Initializer for the `kernel` weights matrix. /// Initializer for the `kernel` weights matrix.
/// </summary> /// </summary>
[JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer; public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;


/// <summary> /// <summary>
/// Initializer for the bias vector. /// Initializer for the bias vector.
/// </summary> /// </summary>
[JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer; public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;


/// <summary> /// <summary>
/// Regularizer function applied to the `kernel` weights matrix. /// Regularizer function applied to the `kernel` weights matrix.
/// </summary> /// </summary>
[JsonProperty("kernel_regularizer")]
public IRegularizer KernelRegularizer { get; set; } public IRegularizer KernelRegularizer { get; set; }


/// <summary> /// <summary>
/// Regularizer function applied to the bias vector. /// Regularizer function applied to the bias vector.
/// </summary> /// </summary>
[JsonProperty("bias_regularizer")]
public IRegularizer BiasRegularizer { get; set; } public IRegularizer BiasRegularizer { get; set; }


/// <summary> /// <summary>
/// Constraint function applied to the `kernel` weights matrix. /// Constraint function applied to the `kernel` weights matrix.
/// </summary> /// </summary>
[JsonProperty("kernel_constraint")]
public Action KernelConstraint { get; set; } public Action KernelConstraint { get; set; }


/// <summary> /// <summary>
/// Constraint function applied to the bias vector. /// Constraint function applied to the bias vector.
/// </summary> /// </summary>
[JsonProperty("bias_constraint")]
public Action BiasConstraint { get; set; } public Action BiasConstraint { get; set; }

[JsonProperty("name")]
public override string Name { get => base.Name; set => base.Name = value; }
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
[JsonProperty("trainable")]
public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
} }
} }

+ 15
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/InputLayerArgs.cs View File

@@ -1,9 +1,22 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using Tensorflow.Keras.Common;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class InputLayerArgs : LayerArgs public class InputLayerArgs : LayerArgs
{ {
[JsonIgnore]
public Tensor InputTensor { get; set; } public Tensor InputTensor { get; set; }
public bool Sparse { get; set; }
[JsonProperty("sparse")]
public virtual bool Sparse { get; set; }
[JsonProperty("ragged")]
public bool Ragged { get; set; } public bool Ragged { get; set; }
[JsonProperty("name")]
public override string Name { get => base.Name; set => base.Name = value; }
[JsonProperty("dtype")]
public override TF_DataType DType { get => base.DType; set => base.DType = value; }
[JsonProperty("batch_input_shape", NullValueHandling = NullValueHandling.Ignore)]
public override Shape BatchInputShape { get => base.BatchInputShape; set => base.BatchInputShape = value; }
} }
} }

+ 2
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/DataAdapterArgs.cs View File

@@ -1,8 +1,9 @@
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class DataAdapterArgs
public class DataAdapterArgs: IKerasConfig
{ {
public Tensor X { get; set; } public Tensor X { get; set; }
public Tensor Y { get; set; } public Tensor Y { get; set; }


+ 2
- 1
src/TensorFlowNET.Core/Keras/ArgsDefinition/DataHandlerArgs.cs View File

@@ -1,8 +1,9 @@
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;


namespace Tensorflow.Keras.ArgsDefinition namespace Tensorflow.Keras.ArgsDefinition
{ {
public class DataHandlerArgs
public class DataHandlerArgs: IKerasConfig
{ {
public Tensor X { get; set; } public Tensor X { get; set; }
public Tensor Y { get; set; } public Tensor Y { get; set; }


+ 17
- 14
src/TensorFlowNET.Core/Keras/ArgsDefinition/LayerArgs.cs View File

@@ -1,51 +1,54 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class LayerArgs
[JsonObject(MemberSerialization.OptIn)]
public class LayerArgs: IKerasConfig
{ {
/// <summary> /// <summary>
/// Indicates whether the layer's weights are updated during training /// Indicates whether the layer's weights are updated during training
/// and whether the layer's updates are run during training. /// and whether the layer's updates are run during training.
/// </summary> /// </summary>
public bool Trainable { get; set; } = true;

public string Name { get; set; }
public virtual bool Trainable { get; set; } = true;
public virtual string Name { get; set; }


/// <summary> /// <summary>
/// Only applicable to input layers. /// Only applicable to input layers.
/// </summary> /// </summary>
public TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT;
public virtual TF_DataType DType { get; set; } = TF_DataType.TF_FLOAT;


/// <summary> /// <summary>
/// Whether the `call` method can be used to build a TF graph without issues. /// Whether the `call` method can be used to build a TF graph without issues.
/// This attribute has no effect if the model is created using the Functional /// This attribute has no effect if the model is created using the Functional
/// API. Instead, `model.dynamic` is determined based on the internal layers. /// API. Instead, `model.dynamic` is determined based on the internal layers.
/// </summary> /// </summary>
public bool Dynamic { get; set; } = false;
public virtual bool Dynamic { get; set; } = false;


/// <summary> /// <summary>
/// Only applicable to input layers. /// Only applicable to input layers.
/// </summary> /// </summary>
public Shape InputShape { get; set; }
public virtual Shape InputShape { get; set; }


/// <summary> /// <summary>
/// Only applicable to input layers. /// Only applicable to input layers.
/// </summary> /// </summary>
public Shape BatchInputShape { get; set; }
public virtual Shape BatchInputShape { get; set; }


public int BatchSize { get; set; } = -1;
public virtual int BatchSize { get; set; } = -1;


/// <summary> /// <summary>
/// Initial weight values. /// Initial weight values.
/// </summary> /// </summary>
public float[] Weights { get; set; }
public virtual float[] Weights { get; set; }


/// <summary> /// <summary>
/// Regularizer function applied to the output of the layer(its "activation"). /// Regularizer function applied to the output of the layer(its "activation").
/// </summary> /// </summary>
public IRegularizer ActivityRegularizer { get; set; }
public virtual IRegularizer ActivityRegularizer { get; set; }


public bool Autocast { get; set; }
public virtual bool Autocast { get; set; }


public bool IsFromConfig { get; set; }
public virtual bool IsFromConfig { get; set; }
} }
} }

+ 4
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/NodeArgs.cs View File

@@ -1,6 +1,8 @@
namespace Tensorflow.Keras.ArgsDefinition
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class NodeArgs
public class NodeArgs: IKerasConfig
{ {
public ILayer[] InboundLayers { get; set; } public ILayer[] InboundLayers { get; set; }
public int[] NodeIndices { get; set; } public int[] NodeIndices { get; set; }


+ 4
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/OptimizerV2Args.cs View File

@@ -1,6 +1,8 @@
namespace Tensorflow.Keras.ArgsDefinition
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class OptimizerV2Args
public class OptimizerV2Args: IKerasConfig
{ {
public string Name { get; set; } public string Name { get; set; }
public float LearningRate { get; set; } = 0.001f; public float LearningRate { get; set; } = 0.001f;


+ 5
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/FlattenArgs.cs View File

@@ -1,7 +1,10 @@
namespace Tensorflow.Keras.ArgsDefinition
using Newtonsoft.Json;

namespace Tensorflow.Keras.ArgsDefinition
{ {
public class FlattenArgs : LayerArgs
public class FlattenArgs : AutoSerializeLayerArgs
{ {
[JsonProperty("data_format")]
public string DataFormat { get; set; } public string DataFormat { get; set; }
} }
} }

+ 50
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs View File

@@ -0,0 +1,50 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
public class CustomizedActivationJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(Activation);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
if (value is null)
{
var token = JToken.FromObject("");
token.WriteTo(writer);
}
else if (value is not Activation)
{
throw new TypeError($"Unable to use `CustomizedActivationJsonConverter` to serialize the type {value.GetType()}.");
}
else
{
var token = JToken.FromObject((value as Activation)!.GetType().Name);
token.WriteTo(writer);
}
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
throw new NotImplementedException();
//var dims = serializer.Deserialize(reader, typeof(string));
//if (dims is null)
//{
// throw new ValueError("Cannot deserialize 'null' to `Activation`.");
//}
//return new Shape((long[])(dims!));
}
}
}

+ 48
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedAxisJsonConverter.cs View File

@@ -0,0 +1,48 @@
using Newtonsoft.Json.Linq;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
public class CustomizedAxisJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(Axis);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
if (value is null)
{
var token = JToken.FromObject(new int[] { });
token.WriteTo(writer);
}
else if (value is not Axis)
{
throw new TypeError($"Unable to use `CustomizedAxisJsonConverter` to serialize the type {value.GetType()}.");
}
else
{
var token = JToken.FromObject((value as Axis)!.axis);
token.WriteTo(writer);
}
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var axis = serializer.Deserialize(reader, typeof(long[]));
if (axis is null)
{
throw new ValueError("Cannot deserialize 'null' to `Axis`.");
}
return new Axis((int[])(axis!));
}
}
}

+ 73
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedNodeConfigJsonConverter.cs View File

@@ -0,0 +1,73 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Keras.Saving;

namespace Tensorflow.Keras.Common
{
public class CustomizedNodeConfigJsonConverter : JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(NodeConfig);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
if (value is null)
{
var token = JToken.FromObject(null);
token.WriteTo(writer);
}
else if (value is not NodeConfig)
{
throw new TypeError($"Unable to use `CustomizedNodeConfigJsonConverter` to serialize the type {value.GetType()}.");
}
else
{
var config = value as NodeConfig;
var token = JToken.FromObject(new object[] { config!.Name, config.NodeIndex, config.TensorIndex });
token.WriteTo(writer);
}
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var values = serializer.Deserialize(reader, typeof(object[])) as object[];
if (values is null)
{
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
}
if(values.Length != 3)
{
throw new ValueError($"The value ({string.Join(", ", values)}) cannot be deserialized to type `NodeConfig`.");
}
if (values[0] is not string)
{
throw new TypeError($"The first value of `NodeConfig` is expected to be `string`, but got `{values[0].GetType().Name}`");
}
if (values[1] is not int)
{
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[1].GetType().Name}`");
}
if (values[2] is not int)
{
throw new TypeError($"The first value of `NodeConfig` is expected to be `int`, but got `{values[2].GetType().Name}`");
}
return new NodeConfig()
{
Name = values[0] as string,
NodeIndex = (int)values[1],
TensorIndex = (int)values[2]
};
}
}
}

+ 67
- 0
src/TensorFlowNET.Core/Keras/Common/CustomizedShapeJsonConverter.cs View File

@@ -0,0 +1,67 @@
using Newtonsoft.Json;
using Newtonsoft.Json.Converters;
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Common
{
public class CustomizedShapeJsonConverter: JsonConverter
{
public override bool CanConvert(Type objectType)
{
return objectType == typeof(Shape);
}

public override bool CanRead => true;

public override bool CanWrite => true;

public override void WriteJson(JsonWriter writer, object? value, JsonSerializer serializer)
{
if(value is null)
{
var token = JToken.FromObject(null);
token.WriteTo(writer);
}
else if(value is not Shape)
{
throw new TypeError($"Unable to use `CustomizedShapeJsonConverter` to serialize the type {value.GetType()}.");
}
else
{
var shape = (value as Shape)!;
long?[] dims = new long?[shape.ndim];
for(int i = 0; i < dims.Length; i++)
{
if (shape.dims[i] == -1)
{
dims[i] = null;
}
else
{
dims[i] = shape.dims[i];
}
}
var token = JToken.FromObject(dims);
token.WriteTo(writer);
}
}

public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
var dims = serializer.Deserialize(reader, typeof(long?[])) as long?[];
if(dims is null)
{
throw new ValueError("Cannot deserialize 'null' to `Shape`.");
}
long[] convertedDims = new long[dims.Length];
for(int i = 0; i < dims.Length; i++)
{
convertedDims[i] = dims[i] ?? (-1);
}
return new Shape(convertedDims);
}
}
}

+ 30
- 1
src/TensorFlowNET.Core/Keras/Engine/InputSpec.cs View File

@@ -16,23 +16,27 @@


using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Keras.Saving;


namespace Tensorflow.Keras.Engine namespace Tensorflow.Keras.Engine
{ {
/// <summary> /// <summary>
/// Specifies the ndim, dtype and shape of every input to a layer. /// Specifies the ndim, dtype and shape of every input to a layer.
/// </summary> /// </summary>
public class InputSpec
public class InputSpec: IKerasConfigable
{ {
public int? ndim; public int? ndim;
public int? max_ndim;
public int? min_ndim; public int? min_ndim;
Dictionary<int, int> axes; Dictionary<int, int> axes;
Shape shape; Shape shape;
TF_DataType dtype;
public int[] AllAxisDim; public int[] AllAxisDim;


public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid, public InputSpec(TF_DataType dtype = TF_DataType.DtInvalid,
int? ndim = null, int? ndim = null,
int? min_ndim = null, int? min_ndim = null,
int? max_ndim = null,
Dictionary<int, int> axes = null, Dictionary<int, int> axes = null,
Shape shape = null) Shape shape = null)
{ {
@@ -41,7 +45,9 @@ namespace Tensorflow.Keras.Engine
axes = new Dictionary<int, int>(); axes = new Dictionary<int, int>();
this.axes = axes; this.axes = axes;
this.min_ndim = min_ndim; this.min_ndim = min_ndim;
this.max_ndim = max_ndim;
this.shape = shape; this.shape = shape;
this.dtype = dtype;
if (ndim == null && shape != null) if (ndim == null && shape != null)
this.ndim = shape.ndim; this.ndim = shape.ndim;


@@ -49,7 +55,30 @@ namespace Tensorflow.Keras.Engine
AllAxisDim = axes.Select(x => x.Value).ToArray(); AllAxisDim = axes.Select(x => x.Value).ToArray();
} }


public IKerasConfig get_config()
{
return new Config()
{
DType = dtype == TF_DataType.DtInvalid ? null : dtype,
Shape = shape,
Ndim = ndim,
MinNdim = min_ndim,
MaxNdim = max_ndim,
Axes = axes.ToDictionary(x => x.Key.ToString(), x => x.Value)
};
}

public override string ToString() public override string ToString()
=> $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}"; => $"ndim={ndim}, min_ndim={min_ndim}, axes={axes.Count}";

public class Config: IKerasConfig
{
public TF_DataType? DType { get; set; }
public Shape Shape { get; set; }
public int? Ndim { get; set; }
public int? MinNdim { get;set; }
public int? MaxNdim { get;set; }
public IDictionary<string, int> Axes { get; set; }
}
} }
} }

+ 3
- 2
src/TensorFlowNET.Core/Keras/Layers/ILayer.cs View File

@@ -1,11 +1,12 @@
using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Training; using Tensorflow.Training;


namespace Tensorflow.Keras namespace Tensorflow.Keras
{ {
public interface ILayer: ITrackable
public interface ILayer: IWithTrackable, IKerasConfigable
{ {
string Name { get; } string Name { get; }
bool Trainable { get; } bool Trainable { get; }
@@ -19,8 +20,8 @@ namespace Tensorflow.Keras
List<IVariableV1> NonTrainableWeights { get; } List<IVariableV1> NonTrainableWeights { get; }
Shape OutputShape { get; } Shape OutputShape { get; }
Shape BatchInputShape { get; } Shape BatchInputShape { get; }
TensorShapeConfig BuildInputShape { get; }
TF_DataType DType { get; } TF_DataType DType { get; }
int count_params(); int count_params();
LayerArgs get_config();
} }
} }

+ 15
- 0
src/TensorFlowNET.Core/Keras/Saving/IKerasConfig.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Keras.Saving
{
public interface IKerasConfig
{
}

public interface IKerasConfigable
{
IKerasConfig get_config();
}
}

+ 7
- 2
src/TensorFlowNET.Core/Keras/Saving/LayerConfig.cs View File

@@ -1,4 +1,5 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
@@ -6,11 +7,15 @@ using Tensorflow.Keras.Engine;


namespace Tensorflow.Keras.Saving namespace Tensorflow.Keras.Saving
{ {
public class LayerConfig
public class LayerConfig: IKerasConfig
{ {
[JsonProperty("name")]
public string Name { get; set; } public string Name { get; set; }
[JsonProperty("class_name")]
public string ClassName { get; set; } public string ClassName { get; set; }
[JsonProperty("config")]
public LayerArgs Config { get; set; } public LayerArgs Config { get; set; }
[JsonProperty("inbound_nodes")]
public List<NodeConfig> InboundNodes { get; set; } public List<NodeConfig> InboundNodes { get; set; }
} }
} }

+ 7
- 2
src/TensorFlowNET.Core/Keras/Saving/ModelConfig.cs View File

@@ -1,15 +1,20 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;


namespace Tensorflow.Keras.Saving namespace Tensorflow.Keras.Saving
{ {
public class ModelConfig
public class ModelConfig : IKerasConfig
{ {
[JsonProperty("name")]
public string Name { get; set; } public string Name { get; set; }
[JsonProperty("layers")]
public List<LayerConfig> Layers { get; set; } public List<LayerConfig> Layers { get; set; }
[JsonProperty("input_layers")]
public List<NodeConfig> InputLayers { get; set; } public List<NodeConfig> InputLayers { get; set; }
[JsonProperty("output_layers")]
public List<NodeConfig> OutputLayers { get; set; } public List<NodeConfig> OutputLayers { get; set; }


public override string ToString() public override string ToString()


+ 5
- 2
src/TensorFlowNET.Core/Keras/Saving/NodeConfig.cs View File

@@ -1,10 +1,13 @@
using System;
using Newtonsoft.Json;
using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Keras.Common;


namespace Tensorflow.Keras.Saving namespace Tensorflow.Keras.Saving
{ {
public class NodeConfig
[JsonConverter(typeof(CustomizedNodeConfigJsonConverter))]
public class NodeConfig : IKerasConfig
{ {
public string Name { get; set; } public string Name { get; set; }
public int NodeIndex { get; set; } public int NodeIndex { get; set; }


+ 21
- 0
src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs View File

@@ -0,0 +1,21 @@
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Keras.Saving
{
public class TensorShapeConfig
{
[JsonProperty("class_name")]
public string ClassName { get; set; } = "TensorShape";
[JsonProperty("items")]
public long?[] Items { get; set; }

public static implicit operator Shape(TensorShapeConfig shape)
=> shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray());

public static implicit operator TensorShapeConfig(Shape shape)
=> new TensorShapeConfig() { Items = shape.dims.Select<long, long?>(x => x == -1 ? null : x).ToArray() };
}
}

+ 10
- 1
src/TensorFlowNET.Core/NumPy/Axis.cs View File

@@ -14,20 +14,29 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Newtonsoft.Json;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Keras.Common;


namespace Tensorflow namespace Tensorflow
{ {
public record Axis(params int[] axis)
[JsonConverter(typeof(CustomizedAxisJsonConverter))]
public class Axis
{ {
public int[] axis { get; set; }
public int size => axis == null ? -1 : axis.Length; public int size => axis == null ? -1 : axis.Length;
public bool IsScalar { get; init; } public bool IsScalar { get; init; }


public int this[int index] => axis[index]; public int this[int index] => axis[index];


public Axis(params int[] axis)
{
this.axis = axis;
}

public static implicit operator int[]?(Axis axis) public static implicit operator int[]?(Axis axis)
=> axis?.axis; => axis?.axis;




+ 3
- 0
src/TensorFlowNET.Core/Numpy/Shape.cs View File

@@ -14,14 +14,17 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Newtonsoft.Json;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Tensorflow.Keras.Common;
using Tensorflow.NumPy; using Tensorflow.NumPy;


namespace Tensorflow namespace Tensorflow
{ {
[JsonConverter(typeof(CustomizedShapeJsonConverter))]
public class Shape public class Shape
{ {
public int ndim => _dims == null ? -1 : _dims.Length; public int ndim => _dims == null ? -1 : _dims.Length;


+ 10
- 0
src/TensorFlowNET.Core/Operations/Initializers/Constant.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class Constant<T> : IInitializer public class Constant<T> : IInitializer
@@ -22,11 +24,19 @@ namespace Tensorflow.Operations.Initializers
T value; T value;
bool _verify_shape; bool _verify_shape;


private readonly Dictionary<string, object> _config;

public string ClassName => "Constant";
public IDictionary<string, object> Config => _config;

public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false) public Constant(T value, TF_DataType dtype = TF_DataType.TF_FLOAT, bool verify_shape = false)
{ {
this.value = value; this.value = value;
this.dtype = dtype; this.dtype = dtype;
_verify_shape = verify_shape; _verify_shape = verify_shape;

_config = new Dictionary<string, object>();
_config["value"] = this.value;
} }


public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)


+ 9
- 1
src/TensorFlowNET.Core/Operations/Initializers/GlorotUniform.cs View File

@@ -14,10 +14,17 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class GlorotUniform : VarianceScaling public class GlorotUniform : VarianceScaling
{ {
private readonly Dictionary<string, object> _config;

public override string ClassName => "GlorotUniform";
public override IDictionary<string, object> Config => _config;

public GlorotUniform(float scale = 1.0f, public GlorotUniform(float scale = 1.0f,
string mode = "FAN_AVG", string mode = "FAN_AVG",
bool uniform = true, bool uniform = true,
@@ -28,7 +35,8 @@ namespace Tensorflow.Operations.Initializers
seed: seed, seed: seed,
dtype: dtype) dtype: dtype)
{ {

_config = new Dictionary<string, object>();
_config["seed"] = _seed;
} }
} }
} }

+ 7
- 0
src/TensorFlowNET.Core/Operations/Initializers/IInitializer.cs View File

@@ -14,10 +14,17 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Newtonsoft.Json;
using System.Collections.Generic;

namespace Tensorflow namespace Tensorflow
{ {
public interface IInitializer public interface IInitializer
{ {
[JsonProperty("class_name")]
string ClassName { get; }
[JsonProperty("config")]
IDictionary<string, object> Config { get; }
Tensor Apply(InitializerArgs args); Tensor Apply(InitializerArgs args);
} }
} }

+ 7
- 0
src/TensorFlowNET.Core/Operations/Initializers/Ones.cs View File

@@ -14,12 +14,19 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class Ones : IInitializer public class Ones : IInitializer
{ {
private TF_DataType dtype; private TF_DataType dtype;


private readonly Dictionary<string, object> _config;

public string ClassName => "Ones";
public IDictionary<string, object> Config => new Dictionary<string, object>();

public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT) public Ones(TF_DataType dtype = TF_DataType.TF_FLOAT)
{ {
this.dtype = dtype; this.dtype = dtype;


+ 5
- 0
src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs View File

@@ -1,9 +1,14 @@
using System; using System;
using System.Collections.Generic;


namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class Orthogonal : IInitializer public class Orthogonal : IInitializer
{ {
private readonly Dictionary<string, object> _config;

public string ClassName => "Orthogonal";
public IDictionary<string, object> Config => throw new NotImplementedException();
public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)
{ {
throw new NotImplementedException(); throw new NotImplementedException();


+ 12
- 0
src/TensorFlowNET.Core/Operations/Initializers/RandomNormal.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class RandomNormal : IInitializer public class RandomNormal : IInitializer
@@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers
private int? seed; private int? seed;
private TF_DataType dtype; private TF_DataType dtype;


private readonly Dictionary<string, object> _config;

public string ClassName => "RandomNormal";
public IDictionary<string, object> Config => _config;

public RandomNormal(float mean = 0.0f, public RandomNormal(float mean = 0.0f,
float stddev = 0.05f, float stddev = 0.05f,
int? seed = null, int? seed = null,
@@ -32,6 +39,11 @@ namespace Tensorflow.Operations.Initializers
this.stddev = stddev; this.stddev = stddev;
this.seed = seed; this.seed = seed;
this.dtype = dtype; this.dtype = dtype;

_config = new Dictionary<string, object>();
_config["mean"] = this.mean;
_config["stddev"] = this.stddev;
_config["seed"] = this.seed;
} }


public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)


+ 12
- 0
src/TensorFlowNET.Core/Operations/Initializers/RandomUniform.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class RandomUniform : IInitializer public class RandomUniform : IInitializer
@@ -23,12 +25,22 @@ namespace Tensorflow.Operations.Initializers
private float maxval; private float maxval;
private TF_DataType dtype; private TF_DataType dtype;


private readonly Dictionary<string, object> _config;

public string ClassName => "RandomUniform";
public IDictionary<string, object> Config => _config;

public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null) public RandomUniform(TF_DataType dtype = TF_DataType.TF_FLOAT, float minval = -0.05f, float maxval = 0.05f, int? seed = null)
{ {
this.dtype = dtype; this.dtype = dtype;
this.minval = minval; this.minval = minval;
this.maxval = maxval; this.maxval = maxval;
this.seed = seed; this.seed = seed;

_config = new Dictionary<string, object>();
_config["minval"] = this.minval;
_config["maxval"] = this.maxval;
_config["seed"] = this.seed;
} }


public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)


+ 11
- 0
src/TensorFlowNET.Core/Operations/Initializers/TruncatedNormal.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class TruncatedNormal : IInitializer public class TruncatedNormal : IInitializer
@@ -23,6 +25,11 @@ namespace Tensorflow.Operations.Initializers
private int? seed; private int? seed;
private TF_DataType dtype; private TF_DataType dtype;


private readonly Dictionary<string, object> _config;

public string ClassName => "TruncatedNormal";
public IDictionary<string, object> Config => _config;

public TruncatedNormal(float mean = 0.0f, public TruncatedNormal(float mean = 0.0f,
float stddev = 1.0f, float stddev = 1.0f,
int? seed = null, int? seed = null,
@@ -32,6 +39,10 @@ namespace Tensorflow.Operations.Initializers
this.stddev = stddev; this.stddev = stddev;
this.seed = seed; this.seed = seed;
this.dtype = dtype; this.dtype = dtype;
_config = new Dictionary<string, object>();
_config["mean"] = this.mean;
_config["stddev"] = this.stddev;
_config["seed"] = this.seed;
} }


public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)


+ 13
- 0
src/TensorFlowNET.Core/Operations/Initializers/VarianceScaling.cs View File

@@ -15,7 +15,9 @@
******************************************************************************/ ******************************************************************************/


using System; using System;
using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Linq.Expressions;


namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
@@ -30,6 +32,11 @@ namespace Tensorflow.Operations.Initializers
protected int? _seed; protected int? _seed;
protected TF_DataType _dtype; protected TF_DataType _dtype;
protected bool _uniform; protected bool _uniform;
private readonly Dictionary<string, object> _config;

public virtual string ClassName => "VarianceScaling";

public virtual IDictionary<string, object> Config => _config;


public VarianceScaling(float factor = 2.0f, public VarianceScaling(float factor = 2.0f,
string mode = "FAN_IN", string mode = "FAN_IN",
@@ -50,6 +57,12 @@ namespace Tensorflow.Operations.Initializers
_seed = seed; _seed = seed;
_dtype = dtype; _dtype = dtype;
_uniform = uniform; _uniform = uniform;

_config = new();
_config["scale"] = _scale;
_config["mode"] = _mode;
_config["distribution"] = _distribution;
_config["seed"] = _seed;
} }


public Tensor Apply(InitializerArgs args) public Tensor Apply(InitializerArgs args)


+ 5
- 0
src/TensorFlowNET.Core/Operations/Initializers/Zeros.cs View File

@@ -14,6 +14,8 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using System.Collections.Generic;

namespace Tensorflow.Operations.Initializers namespace Tensorflow.Operations.Initializers
{ {
public class Zeros : IInitializer public class Zeros : IInitializer
@@ -21,6 +23,9 @@ namespace Tensorflow.Operations.Initializers
Shape shape; Shape shape;
TF_DataType dtype; TF_DataType dtype;


public string ClassName => "Zeros";
public IDictionary<string, object> Config => new Dictionary<string, object>();

public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT) public Zeros(Shape shape = null, TF_DataType dtype = TF_DataType.TF_FLOAT)
{ {
this.shape = shape; this.shape = shape;


+ 4
- 1
src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs View File

@@ -20,6 +20,7 @@ using Tensorflow.Keras;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Operations; using Tensorflow.Operations;
using Tensorflow.Train; using Tensorflow.Train;
using Tensorflow.Util; using Tensorflow.Util;
@@ -76,6 +77,8 @@ namespace Tensorflow


public Shape BatchInputShape => throw new NotImplementedException(); public Shape BatchInputShape => throw new NotImplementedException();


public TensorShapeConfig BuildInputShape => throw new NotImplementedException();

public TF_DataType DType => throw new NotImplementedException(); public TF_DataType DType => throw new NotImplementedException();
protected bool built = false; protected bool built = false;
public bool Built => built; public bool Built => built;
@@ -144,7 +147,7 @@ namespace Tensorflow
throw new NotImplementedException(); throw new NotImplementedException();
} }


public LayerArgs get_config()
public IKerasConfig get_config()
{ {
throw new NotImplementedException(); throw new NotImplementedException();
} }


+ 1
- 0
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -108,6 +108,7 @@ https://tensorflownet.readthedocs.io</Description>


<ItemGroup> <ItemGroup>
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
<PackageReference Include="Newtonsoft.Json" Version="13.0.2" />
<PackageReference Include="Protobuf.Text" Version="0.5.0" /> <PackageReference Include="Protobuf.Text" Version="0.5.0" />
<PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" /> <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
</ItemGroup> </ItemGroup>


+ 18
- 0
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -202,6 +202,24 @@ namespace Tensorflow
_ => type.ToString() _ => type.ToString()
}; };


public static string as_python_name(this TF_DataType type)
=> type switch
{
TF_DataType.TF_STRING => "str",
TF_DataType.TF_UINT8 => "uint8",
TF_DataType.TF_INT8 => "int8",
TF_DataType.TF_UINT32 => "uint32",
TF_DataType.TF_INT32 => "int32",
TF_DataType.TF_UINT64 => "uint64",
TF_DataType.TF_INT64 => "int64",
TF_DataType.TF_FLOAT => "float32",
TF_DataType.TF_DOUBLE => "float64",
TF_DataType.TF_BOOL => "bool",
TF_DataType.TF_RESOURCE => "resource",
TF_DataType.TF_VARIANT => "variant",
_ => type.ToString()
};

public static int get_datatype_size(this TF_DataType type) public static int get_datatype_size(this TF_DataType type)
=> type.as_base_dtype() switch => type.as_base_dtype() switch
{ {


src/TensorFlowNET.Core/Training/ITrackable.cs → src/TensorFlowNET.Core/Training/IWithTrackable.cs View File

@@ -5,7 +5,7 @@ using Tensorflow.Train;


namespace Tensorflow.Training namespace Tensorflow.Training
{ {
public interface ITrackable
public interface IWithTrackable
{ {
Trackable GetTrackable(); Trackable GetTrackable();
} }

+ 1
- 1
src/TensorFlowNET.Core/Training/Trackable.cs View File

@@ -26,7 +26,7 @@ using static Tensorflow.Binding;


namespace Tensorflow.Train namespace Tensorflow.Train
{ {
public abstract class Trackable: ITrackable
public abstract class Trackable: IWithTrackable
{ {
/// <summary> /// <summary>
/// Corresponding to tensorflow/python/trackable/constants.py /// Corresponding to tensorflow/python/trackable/constants.py


+ 17
- 14
src/TensorFlowNET.Keras/Engine/Functional.GetConfig.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow.Keras.Engine
{ {
public partial class Functional public partial class Functional
{ {
public ModelConfig get_config()
public override IKerasConfig get_config()
{ {
return get_network_config(); return get_network_config();
} }
@@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine
{ {
Name = name Name = name
}; };
var node_conversion_map = new Dictionary<string, int>(); var node_conversion_map = new Dictionary<string, int>();
foreach (var layer in _self_tracked_trackables) foreach (var layer in _self_tracked_trackables)
{ {
@@ -42,23 +42,26 @@ namespace Tensorflow.Keras.Engine
} }


var layer_configs = new List<LayerConfig>(); var layer_configs = new List<LayerConfig>();
foreach (var layer in _self_tracked_trackables)
using (SharedObjectSavingScope.Enter())
{ {
var filtered_inbound_nodes = new List<NodeConfig>();
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
foreach (var layer in _self_tracked_trackables)
{ {
var node_key = _make_node_key(layer.Name, original_node_index);
if (NetworkNodes.Contains(node_key) && !node.is_input)
var filtered_inbound_nodes = new List<NodeConfig>();
foreach (var (original_node_index, node) in enumerate(layer.InboundNodes))
{ {
var node_data = node.serialize(_make_node_key, node_conversion_map);
filtered_inbound_nodes.append(node_data);
var node_key = _make_node_key(layer.Name, original_node_index);
if (NetworkNodes.Contains(node_key) && !node.is_input)
{
var node_data = node.serialize(_make_node_key, node_conversion_map);
filtered_inbound_nodes.append(node_data);
}
} }
}


var layer_config = generic_utils.serialize_layer_to_config(layer);
layer_config.Name = layer.Name;
layer_config.InboundNodes = filtered_inbound_nodes;
layer_configs.Add(layer_config);
var layer_config = generic_utils.serialize_layer_to_config(layer);
layer_config.Name = layer.Name;
layer_config.InboundNodes = filtered_inbound_nodes;
layer_configs.Add(layer_config);
}
} }
config.Layers = layer_configs; config.Layers = layer_configs;




+ 18
- 0
src/TensorFlowNET.Keras/Engine/Functional.cs View File

@@ -70,6 +70,7 @@ namespace Tensorflow.Keras.Engine
this.inputs = inputs; this.inputs = inputs;
this.outputs = outputs; this.outputs = outputs;
built = true; built = true;
_buildInputShape = inputs.shape;


if (outputs.Any(x => x.KerasHistory == null)) if (outputs.Any(x => x.KerasHistory == null))
base_layer_utils.create_keras_history(outputs); base_layer_utils.create_keras_history(outputs);
@@ -357,5 +358,22 @@ namespace Tensorflow.Keras.Engine
return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache)) return LayerCheckpointDependencies.ToDictionary(x => x.Key, x => x.Value.GetTrackable()).Concat(base._trackable_children(save_type, cache))
.ToDictionary(x => x.Key, x => x.Value); .ToDictionary(x => x.Key, x => x.Value);
} }

protected override void _init_set_name(string name, bool zero_based = true)
{
if (string.IsNullOrEmpty(name))
{
string class_name = GetType().Name;
if (this.GetType() == typeof(Functional))
{
class_name = "Model";
}
this.name = base_layer_utils.unique_layer_name(generic_utils.to_snake_case(class_name), zero_based: zero_based);
}
else
{
this.name = name;
}
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -61,6 +61,7 @@ namespace Tensorflow.Keras.Engine
/// Provides information about which inputs are compatible with the layer. /// Provides information about which inputs are compatible with the layer.
/// </summary> /// </summary>
protected InputSpec inputSpec; protected InputSpec inputSpec;
public InputSpec InputSpec => inputSpec;
bool dynamic = true; bool dynamic = true;
public bool SupportsMasking { get; set; } public bool SupportsMasking { get; set; }
protected List<IVariableV1> _trainable_weights; protected List<IVariableV1> _trainable_weights;
@@ -79,6 +80,8 @@ namespace Tensorflow.Keras.Engine
protected bool computePreviousMask; protected bool computePreviousMask;
protected List<Operation> updates; protected List<Operation> updates;
public Shape BatchInputShape => args.BatchInputShape; public Shape BatchInputShape => args.BatchInputShape;
protected TensorShapeConfig _buildInputShape = null;
public TensorShapeConfig BuildInputShape => _buildInputShape;


List<INode> inboundNodes; List<INode> inboundNodes;
public List<INode> InboundNodes => inboundNodes; public List<INode> InboundNodes => inboundNodes;
@@ -223,6 +226,7 @@ namespace Tensorflow.Keras.Engine


public virtual void build(Shape input_shape) public virtual void build(Shape input_shape)
{ {
_buildInputShape = input_shape;
built = true; built = true;
} }


@@ -310,7 +314,7 @@ namespace Tensorflow.Keras.Engine


public List<IVariableV1> Variables => weights; public List<IVariableV1> Variables => weights;


public virtual LayerArgs get_config()
public virtual IKerasConfig get_config()
=> args; => args;
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Engine/Model.Save.cs View File

@@ -1,6 +1,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using Tensorflow.Functions; using Tensorflow.Functions;
using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Saving.SavedModel; using Tensorflow.Keras.Saving.SavedModel;
using Tensorflow.ModelSaving; using Tensorflow.ModelSaving;


@@ -30,7 +31,10 @@ namespace Tensorflow.Keras.Engine
} }
else else
{ {
KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces);
using (SharedObjectSavingScope.Enter())
{
KerasSavedModelUtils.Save(this, filepath, overwrite, include_optimizer, signatures, options, save_traces);
}
} }
} }
} }


+ 1
- 0
src/TensorFlowNET.Keras/Layers/Activation/ELU.cs View File

@@ -25,6 +25,7 @@ namespace Tensorflow.Keras.Layers {
{ {
throw new ValueError("Alpha must be a number greater than 0."); throw new ValueError("Alpha must be a number greater than 0.");
} }
_buildInputShape = input_shape;
built = true; built = true;
} }




+ 1
- 0
src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs View File

@@ -14,6 +14,7 @@ namespace Tensorflow.Keras.Layers {
} }
public override void build(Shape input_shape) public override void build(Shape input_shape)
{ {
_buildInputShape = input_shape;
built = true; built = true;
} }
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)


+ 5
- 4
src/TensorFlowNET.Keras/Layers/Activation/SELU.cs View File

@@ -16,10 +16,11 @@ namespace Tensorflow.Keras.Layers {
// SELU has no arguments // SELU has no arguments
} }
public override void build(Shape input_shape) { public override void build(Shape input_shape) {
if ( alpha < 0f ) {
throw new ValueError("Alpha must be a number greater than 0.");
}
built = true;
if ( alpha < 0f ) {
throw new ValueError("Alpha must be a number greater than 0.");
}
_buildInputShape = input_shape;
built = true;
} }
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
Tensor output = inputs; Tensor output = inputs;


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Attention/Attention.cs View File

@@ -4,6 +4,7 @@ using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;


namespace Tensorflow.Keras.Layers namespace Tensorflow.Keras.Layers
{ {
@@ -146,7 +147,7 @@ namespace Tensorflow.Keras.Layers
return scores; return scores;
} }


public override LayerArgs get_config() => this.args;
public override IKerasConfig get_config() => this.args;
//var config = new Dictionary<object, object> { //var config = new Dictionary<object, object> {
// { // {
// "use_scale", // "use_scale",


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Attention/BaseDenseAttention.cs View File

@@ -5,6 +5,7 @@ using static Tensorflow.KerasApi;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using Tensorflow.Keras.Saving;


/// <summary> /// <summary>
/// Base class for attention layers that can be used in sequence DNN/CNN models. /// Base class for attention layers that can be used in sequence DNN/CNN models.
@@ -252,6 +253,6 @@ namespace Tensorflow.Keras.Layers
return tf.logical_and(x, y); return tf.logical_and(x, y);
} }


public override LayerArgs get_config() => this.args;
public override IKerasConfig get_config() => this.args;
} }
} }

+ 1
- 0
src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs View File

@@ -49,6 +49,7 @@ namespace Tensorflow.Keras.Layers
initializer: bias_initializer, initializer: bias_initializer,
trainable: true); trainable: true);
built = true; built = true;
_buildInputShape = input_shape;
} }


protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)


+ 1
- 0
src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs View File

@@ -98,6 +98,7 @@ namespace Tensorflow.Keras.Layers
name: tf_op_name); name: tf_op_name);


built = true; built = true;
_buildInputShape = input_shape;
} }


protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = false)


+ 1
- 0
src/TensorFlowNET.Keras/Layers/Core/Dense.cs View File

@@ -43,6 +43,7 @@ namespace Tensorflow.Keras.Layers


public override void build(Shape input_shape) public override void build(Shape input_shape)
{ {
_buildInputShape = input_shape;
var last_dim = input_shape.dims.Last(); var last_dim = input_shape.dims.Last();
var axes = new Dictionary<int, int>(); var axes = new Dictionary<int, int>();
axes[-1] = (int)last_dim; axes[-1] = (int)last_dim;


+ 1
- 0
src/TensorFlowNET.Keras/Layers/Core/Embedding.cs View File

@@ -62,6 +62,7 @@ namespace Tensorflow.Keras.Layers
name: "embeddings"); name: "embeddings");
tf.Context.graph_mode(); tf.Context.graph_mode();
built = true; built = true;
_buildInputShape = input_shape;
} }


protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)


+ 1
- 0
src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs View File

@@ -22,6 +22,7 @@ namespace Tensorflow.Keras.Layers {
throw new ValueError("The `cropping` argument must be a tuple of 2 integers."); throw new ValueError("The `cropping` argument must be a tuple of 2 integers.");
} }
built = true; built = true;
_buildInputShape = input_shape;
} }


protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs View File

@@ -13,7 +13,8 @@ namespace Tensorflow.Keras.Layers {
this.args = args; this.args = args;
} }
public override void build(Shape input_shape) { public override void build(Shape input_shape) {
built = true;
built = true;
_buildInputShape = input_shape;
} }
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
Tensor output = inputs; Tensor output = inputs;


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs View File

@@ -12,7 +12,8 @@ namespace Tensorflow.Keras.Layers {
} }


public override void build(Shape input_shape) { public override void build(Shape input_shape) {
built = true;
built = true;
_buildInputShape = input_shape;
} }


protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) { protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {


+ 3
- 1
src/TensorFlowNET.Keras/Layers/LayersApi.cs View File

@@ -300,7 +300,8 @@ namespace Tensorflow.Keras.Layers
=> new Dense(new DenseArgs => new Dense(new DenseArgs
{ {
Units = units, Units = units,
Activation = GetActivationByName("linear")
Activation = GetActivationByName("linear"),
ActivationName = "linear"
}); });


/// <summary> /// <summary>
@@ -321,6 +322,7 @@ namespace Tensorflow.Keras.Layers
{ {
Units = units, Units = units,
Activation = GetActivationByName(activation), Activation = GetActivationByName(activation),
ActivationName = activation,
InputShape = input_shape InputShape = input_shape
}); });




+ 1
- 0
src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs View File

@@ -37,6 +37,7 @@ namespace Tensorflow.Keras.Layers
}).ToArray(); }).ToArray();
shape_set.Add(shape); shape_set.Add(shape);
}*/ }*/
_buildInputShape = input_shape;
} }


protected override Tensors _merge_function(Tensors inputs) protected override Tensors _merge_function(Tensors inputs)


+ 1
- 0
src/TensorFlowNET.Keras/Layers/Merging/Merge.cs View File

@@ -17,6 +17,7 @@ namespace Tensorflow.Keras.Layers
public override void build(Shape input_shape) public override void build(Shape input_shape)
{ {
// output_shape = input_shape.dims[1^]; // output_shape = input_shape.dims[1^];
_buildInputShape = input_shape;
} }


protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)


+ 1
- 0
src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs View File

@@ -118,6 +118,7 @@ namespace Tensorflow.Keras.Layers
throw new NotImplementedException("build when renorm is true"); throw new NotImplementedException("build when renorm is true");


built = true; built = true;
_buildInputShape = input_shape;
} }


public override Shape ComputeOutputShape(Shape input_shape) public override Shape ComputeOutputShape(Shape input_shape)


+ 1
- 0
src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs View File

@@ -81,6 +81,7 @@ namespace Tensorflow.Keras.Layers
_fused = _fused_can_be_used(ndims); _fused = _fused_can_be_used(ndims);


built = true; built = true;
_buildInputShape = input_shape;
} }


bool _fused_can_be_used(int ndims) bool _fused_can_be_used(int ndims)


+ 1
- 0
src/TensorFlowNET.Keras/Layers/Reshaping/Permute.cs View File

@@ -24,6 +24,7 @@ namespace Tensorflow.Keras.Layers {
permute = new int[input_shape.rank]; permute = new int[input_shape.rank];
dims.CopyTo(permute, 1); dims.CopyTo(permute, 1);
built = true; built = true;
_buildInputShape = input_shape;
} }
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null) protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{ {


+ 1
- 0
src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs View File

@@ -18,6 +18,7 @@ namespace Tensorflow.Keras.Layers.Rnn
public override void build(Shape input_shape) public override void build(Shape input_shape)
{ {
var input_dim = input_shape[-1]; var input_dim = input_shape[-1];
_buildInputShape = input_shape;


kernel = add_weight("kernel", (input_shape[-1], args.Units), kernel = add_weight("kernel", (input_shape[-1], args.Units),
initializer: args.KernelInitializer initializer: args.KernelInitializer


+ 2
- 1
src/TensorFlowNET.Keras/Layers/Rnn/StackedRNNCells.cs View File

@@ -4,6 +4,7 @@ using System.ComponentModel;
using Tensorflow.Keras.ArgsDefinition; using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn; using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;


namespace Tensorflow.Keras.Layers.Rnn namespace Tensorflow.Keras.Layers.Rnn
{ {
@@ -136,7 +137,7 @@ namespace Tensorflow.Keras.Layers.Rnn
// self.built = True // self.built = True
} }


public override LayerArgs get_config()
public override IKerasConfig get_config()
{ {
throw new NotImplementedException(); throw new NotImplementedException();
//def get_config(self): //def get_config(self):


+ 1
- 1
src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs View File

@@ -79,7 +79,7 @@ public partial class KerasSavedModelUtils


var path = node_paths[node]; var path = node_paths[node];
string node_path; string node_path;
if (path is null)
if (path is null || path.Count() == 0)
{ {
node_path = "root"; node_path = "root";
} }


+ 21
- 12
src/TensorFlowNET.Keras/Saving/SavedModel/layer_serialization.cs View File

@@ -1,6 +1,7 @@
using System.Collections.Generic; using System.Collections.Generic;
using Newtonsoft.Json; using Newtonsoft.Json;
using Newtonsoft.Json.Linq; using Newtonsoft.Json.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Layers; using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Utils; using Tensorflow.Keras.Utils;
@@ -85,31 +86,38 @@ public class LayerSavedModelSaver: SavedModelSaver
JObject metadata = new JObject(); JObject metadata = new JObject();
metadata["name"] = _layer.Name; metadata["name"] = _layer.Name;
metadata["trainable"] = _layer.Trainable; metadata["trainable"] = _layer.Trainable;
// metadata["expects_training_arg"] = _obj._expects_training_arg;
// metadata["dtype"] = policy.serialize(_obj._dtype_policy)
// TODO: implement `expects_training_arg`.
metadata["expects_training_arg"] = false;
metadata["dtype"] = _layer.DType.as_python_name();
metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape); metadata["batch_input_shape"] = _layer.BatchInputShape is null ? null : JToken.FromObject(_layer.BatchInputShape);
// metadata["stateful"] = _obj.stateful; // metadata["stateful"] = _obj.stateful;
// metadata["must_restore_from_config"] = _obj.must_restore_from_config; // metadata["must_restore_from_config"] = _obj.must_restore_from_config;
// metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config; // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config;
metadata["autocast"] = _layer.AutoCast; metadata["autocast"] = _layer.AutoCast;


var temp = JObject.FromObject(get_serialized(_layer));
metadata.Merge(temp, new JsonMergeSettings
if(_layer.InputSpec is not null)
{
metadata["input_spec"] = generic_utils.serialize_keras_object(_layer.InputSpec);
}

metadata.Merge(get_serialized(_layer), new JsonMergeSettings
{ {
// Handle conflicts by using values from obj2 // Handle conflicts by using values from obj2
MergeArrayHandling = MergeArrayHandling.Merge MergeArrayHandling = MergeArrayHandling.Merge
}); });
// skip the check of `input_spec` and `build_input_shape` for the lack of members. // skip the check of `input_spec` and `build_input_shape` for the lack of members.
// skip the check of `activity_regularizer` for the type problem. // skip the check of `activity_regularizer` for the type problem.
if(_layer.BuildInputShape is not null)
{
metadata["build_input_shape"] = JToken.FromObject(_layer.BuildInputShape);
}
return metadata.ToString(); return metadata.ToString();
} }
} }


public static IDictionary<string, object> get_serialized(Layer obj)
public static JObject get_serialized(Layer obj)
{ {
// TODO: complete the implmentation (need to revise `get_config`).
return new Dictionary<string, object>();
//return generic_utils.serialize_keras_object(obj);
return generic_utils.serialize_keras_object(obj);
} }
} }


@@ -135,18 +143,19 @@ public class InputLayerSavedModelSaver: SavedModelSaver
{ {
get get
{ {
if(_obj is not Layer)
if(_obj is not InputLayer)
{ {
throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer."); throw new TypeError($"The type {_obj.GetType()} cannot be recognized as an input layer.");
} }
var layer = (Layer)_obj;
var layer = (InputLayer)_obj;
var config = (layer.get_config() as InputLayerArgs)!;
var info = new var info = new
{ {
class_name = layer.GetType().Name, class_name = layer.GetType().Name,
name = layer.Name, name = layer.Name,
dtype = layer.DType, dtype = layer.DType,
//sparse = layer.sparse,
//ragged = layer.ragged,
sparse = config.Sparse,
ragged = config.Ragged,
batch_input_shape = layer.BatchInputShape, batch_input_shape = layer.BatchInputShape,
config = layer.get_config() config = layer.get_config()
}; };


+ 0
- 15
src/TensorFlowNET.Keras/Saving/TensorShapeConfig.cs View File

@@ -1,15 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;

namespace Tensorflow.Keras.Saving
{
public class TensorShapeConfig
{
public string ClassName { get; set; }
public int?[] Items { get; set; }

public static implicit operator Shape(TensorShapeConfig shape)
=> shape == null ? null : new Shape(shape.Items.Select(x => x.HasValue ? x.Value : -1).ToArray());
}
}

+ 125
- 0
src/TensorFlowNET.Keras/Saving/serialization.cs View File

@@ -0,0 +1,125 @@
using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using Tensorflow.Keras.Saving.SavedModel;

namespace Tensorflow.Keras.Saving
{
// TODO: make it thread safe.
public class SharedObjectSavingScope: IDisposable
{
private class WeakReferenceEqualityComparer: IEqualityComparer<WeakReference<object>>
{
public bool Equals(WeakReference<object> x, WeakReference<object> y)
{
if(!x.TryGetTarget(out var tx))
{
return false;
}
if(!y.TryGetTarget(out var ty))
{
return false;
}
return tx.Equals(ty);
}
public int GetHashCode(WeakReference<object> obj)
{
if (!obj.TryGetTarget(out var w))
{
return 0;
}
return w.GetHashCode();
}
}
private static SharedObjectSavingScope? _instance = null;
private readonly Dictionary<WeakReference<object>, int> _shared_object_ids= new Dictionary<WeakReference<object>, int>();
private int _currentId = 0;
/// <summary>
/// record how many times the scope is nested.
/// </summary>
private int _nestedDepth = 0;
private SharedObjectSavingScope()
{

}

public static SharedObjectSavingScope Enter()
{
if(_instance is not null)
{
_instance._nestedDepth++;
return _instance;
}
else
{
_instance = new SharedObjectSavingScope();
_instance._nestedDepth++;
return _instance;
}
}

public static SharedObjectSavingScope GetScope()
{
return _instance;
}

public int GetId(object? obj)
{
if(obj is null)
{
return _currentId++;
}
var maybe_key = _shared_object_ids.Keys.SingleOrDefault(x => new WeakReferenceEqualityComparer().Equals(x, new WeakReference<object>(obj)));
if (maybe_key is not null)
{
return _shared_object_ids[maybe_key];
}
_shared_object_ids[new WeakReference<object>(obj)] = _currentId++;
return _currentId;
}

public void Dispose()
{
_nestedDepth--;
if(_nestedDepth== 0)
{
_instance = null;
}
}
}

public static class serialize_utils
{
public static readonly string SHARED_OBJECT_KEY = "shared_object_id";
/// <summary>
/// Returns the serialization of the class with the given config.
/// </summary>
/// <param name="class_name"></param>
/// <param name="config"></param>
/// <param name="obj"></param>
/// <param name="shared_object_id"></param>
/// <returns></returns>
public static JObject serialize_keras_class_and_config(string class_name, JToken config, object? obj = null, int? shared_object_id = null)
{
JObject res = new JObject();
res["class_name"] = class_name;
res["config"] = config;

if(shared_object_id is not null)
{
res[SHARED_OBJECT_KEY] = shared_object_id!;
}

var scope = SharedObjectSavingScope.GetScope();
if(scope is not null && obj is not null)
{
res[SHARED_OBJECT_KEY] = scope.GetId(obj);
}

return res;
}
}
}

+ 1
- 1
src/TensorFlowNET.Keras/Utils/base_layer_utils.cs View File

@@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Utils
} }


/// <summary> /// <summary>
/// Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
/// Makes a layer name (or arbitrary string) unique within a TensorFlow graph. (correponding to `backend.unique_object_name` of python.)
/// </summary> /// </summary>
/// <param name="name"></param> /// <param name="name"></param>
/// <returns></returns> /// <returns></returns>


+ 13
- 1
src/TensorFlowNET.Keras/Utils/generic_utils.cs View File

@@ -14,10 +14,14 @@
limitations under the License. limitations under the License.
******************************************************************************/ ******************************************************************************/


using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using System; using System;
using System.Collections; using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics;
using System.Linq; using System.Linq;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving; using Tensorflow.Keras.Saving;


namespace Tensorflow.Keras.Utils namespace Tensorflow.Keras.Utils
@@ -32,13 +36,21 @@ namespace Tensorflow.Keras.Utils
public static LayerConfig serialize_layer_to_config(ILayer instance) public static LayerConfig serialize_layer_to_config(ILayer instance)
{ {
var config = instance.get_config(); var config = instance.get_config();
Debug.Assert(config is LayerArgs);
return new LayerConfig return new LayerConfig
{ {
Config = config,
Config = config as LayerArgs,
ClassName = instance.GetType().Name ClassName = instance.GetType().Name
}; };
} }


public static JObject serialize_keras_object(IKerasConfigable instance)
{
var config = JToken.FromObject(instance.get_config());
// TODO: change the class_name to registered name, instead of system class name.
return serialize_utils.serialize_keras_class_and_config(instance.GetType().Name, config, instance);
}

public static string to_snake_case(string name) public static string to_snake_case(string name)
{ {
return string.Concat(name.Select((x, i) => return string.Concat(name.Select((x, i) =>


+ 4
- 1
test/TensorFlowNET.Keras.UnitTest/Layers/ModelSaveTest.cs View File

@@ -1,6 +1,8 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.Keras.Engine; using Tensorflow.Keras.Engine;
using System.Diagnostics;
using static Tensorflow.KerasApi; using static Tensorflow.KerasApi;
using Tensorflow.Keras.Saving;


namespace TensorFlowNET.Keras.UnitTest namespace TensorFlowNET.Keras.UnitTest
{ {
@@ -15,7 +17,8 @@ namespace TensorFlowNET.Keras.UnitTest
{ {
var model = GetFunctionalModel(); var model = GetFunctionalModel();
var config = model.get_config(); var config = model.get_config();
var new_model = keras.models.from_config(config);
Debug.Assert(config is ModelConfig);
var new_model = keras.models.from_config(config as ModelConfig);
Assert.AreEqual(model.Layers.Count, new_model.Layers.Count); Assert.AreEqual(model.Layers.Count, new_model.Layers.Count);
} }




+ 31
- 9
test/TensorFlowNET.Keras.UnitTest/SaveTest.cs View File

@@ -15,17 +15,14 @@ using Tensorflow.Keras.Layers;
using Tensorflow.Keras.Losses; using Tensorflow.Keras.Losses;
using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Metrics;
using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Optimizers;
using Tensorflow.Operations;


namespace TensorFlowNET.Keras.UnitTest; namespace TensorFlowNET.Keras.UnitTest;


// class MNISTLoader
// {
// public MNISTLoader()
// {
// var mnist = new MnistModelLoader()
//
// }
// }
public static class AutoGraphExtension
{
}


[TestClass] [TestClass]
public class SaveTest public class SaveTest
@@ -42,6 +39,8 @@ public class SaveTest
model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"}); model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"});


var g = ops.get_default_graph();

var data_loader = new MnistModelLoader(); var data_loader = new MnistModelLoader();
var num_epochs = 1; var num_epochs = 1;
var batch_size = 50; var batch_size = 50;
@@ -50,11 +49,34 @@ public class SaveTest
{ {
TrainDir = "mnist", TrainDir = "mnist",
OneHot = false, OneHot = false,
ValidationSize = 0,
ValidationSize = 50000,
}).Result; }).Result;
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb");
} }

[TestMethod]
public void Temp()
{
var graph = new Graph();
var g = graph.as_default();
//var input_tensor = array_ops.placeholder(TF_DataType.TF_FLOAT, new int[] { 1 }, "test_string_tensor");
var input_tensor = tf.placeholder(tf.int32, new int[] { 1 }, "aa");
var wrapped_func = tf.autograph.to_graph(func);
var res = wrapped_func(input_tensor);
g.Exit();
}

private Tensor func(Tensor tensor)
{
return gen_ops.neg(tensor);
//return array_ops.identity(tensor);
//tf.device("cpu:0");
//using (ops.control_dependencies(new object[] { res.op }))
//{
// return array_ops.identity(tensor);
//}
}
} }

Loading…
Cancel
Save