diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs
index 23552316..e830e5bf 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/ELUArgs.cs
@@ -1,9 +1,12 @@
-using System;
+using Newtonsoft.Json;
+using System;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Keras.ArgsDefinition {
- public class ELUArgs : LayerArgs {
- public float Alpha { get; set; } = 0.1f;
- }
+ public class ELUArgs : AutoSerializeLayerArgs
+ {
+ [JsonProperty("alpha")]
+ public float Alpha { get; set; } = 0.1f;
+ }
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs
index 6bdb294c..6d953134 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/LeakyReLuArgs.cs
@@ -1,14 +1,16 @@
-using System;
+using Newtonsoft.Json;
+using System;
using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Keras.ArgsDefinition
{
- public class LeakyReLuArgs : LayerArgs
+ public class LeakyReLuArgs : AutoSerializeLayerArgs
{
///
/// Negative slope coefficient.
///
+ [JsonProperty("alpha")]
public float Alpha { get; set; } = 0.3f;
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs
index a37973bc..1c1d147f 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Activation/SoftmaxArgs.cs
@@ -4,15 +4,9 @@ using System.Collections.Generic;
using System.Text;
namespace Tensorflow.Keras.ArgsDefinition {
- public class SoftmaxArgs : LayerArgs
+ public class SoftmaxArgs : AutoSerializeLayerArgs
{
[JsonProperty("axis")]
public Axis axis { get; set; } = -1;
- [JsonProperty("name")]
- public override string Name { get => base.Name; set => base.Name = value; }
- [JsonProperty("trainable")]
- public override bool Trainable { get => base.Trainable; set => base.Trainable = value; }
- [JsonProperty("dtype")]
- public override TF_DataType DType { get => base.DType; set => base.DType = value; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs
index 73477c58..4cdfb46b 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/AttentionArgs.cs
@@ -1,3 +1,5 @@
+using Newtonsoft.Json;
+
namespace Tensorflow.Keras.ArgsDefinition
{
public class AttentionArgs : BaseDenseAttentionArgs
@@ -6,6 +8,7 @@ namespace Tensorflow.Keras.ArgsDefinition
///
/// If `true`, will create a scalar variable to scale the attention scores.
///
+ [JsonProperty("use_scale")]
public bool use_scale { get; set; } = false;
///
@@ -14,6 +17,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// and key vectors. `"concat"` refers to the hyperbolic tangent of the
/// concatenation of the query and key vectors.
///
+ [JsonProperty("score_mode")]
public string score_mode { get; set; } = "dot";
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs
index b2a0c3a5..0ef01737 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/BaseDenseAttentionArgs.cs
@@ -1,6 +1,8 @@
+using Newtonsoft.Json;
+
namespace Tensorflow.Keras.ArgsDefinition
{
- public class BaseDenseAttentionArgs : LayerArgs
+ public class BaseDenseAttentionArgs : AutoSerializeLayerArgs
{
///
@@ -14,6 +16,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// Float between 0 and 1. Fraction of the units to drop for the
/// attention scores.
///
+ [JsonProperty("dropout")]
public float dropout { get; set; } = 0f;
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs
index 21b2d218..077dea89 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/MultiHeadAttentionArgs.cs
@@ -1,22 +1,40 @@
+using Newtonsoft.Json;
using System;
using static Tensorflow.Binding;
namespace Tensorflow.Keras.ArgsDefinition
{
- public class MultiHeadAttentionArgs : LayerArgs
+ public class MultiHeadAttentionArgs : AutoSerializeLayerArgs
{
+ [JsonProperty("num_heads")]
public int NumHeads { get; set; }
+ [JsonProperty("key_dim")]
public int KeyDim { get; set; }
+ [JsonProperty("value_dim")]
public int? ValueDim { get; set; } = null;
+ [JsonProperty("dropout")]
public float Dropout { get; set; } = 0f;
+ [JsonProperty("use_bias")]
public bool UseBias { get; set; } = true;
+ [JsonProperty("output_shape")]
public Shape OutputShape { get; set; } = null;
+ [JsonProperty("attention_axes")]
public Shape AttentionAxis { get; set; } = null;
+ [JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
+ [JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
+ [JsonProperty("kernel_regularizer")]
public IRegularizer KernelRegularizer { get; set; } = null;
+ [JsonProperty("bias_regularizer")]
public IRegularizer BiasRegularizer { get; set; } = null;
+ [JsonProperty("kernel_constraint")]
public Action KernelConstraint { get; set; } = null;
+ [JsonProperty("bias_constraint")]
public Action BiasConstraint { get; set; } = null;
+ [JsonProperty("activity_regularizer")]
+ public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; }
+
+ // TODO: Add `key_shape`, `value_shape`, `query_shape`.
}
}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs
index 66b34a1a..1a97b013 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/AutoSerializeLayerArgs.cs
@@ -5,6 +5,12 @@ using System.Text;
namespace Tensorflow.Keras.ArgsDefinition
{
+ ///
+ /// This class has nothing but the attributes different from `LayerArgs`.
+ /// It's used to serialize the model to `tf` format.
+ /// If the `get_config` of a `Layer` in python code of tensorflow contains `super().get_config`,
+ /// then the Arg definition should inherit `utoSerializeLayerArgs` instead of `LayerArgs`.
+ ///
public class AutoSerializeLayerArgs: LayerArgs
{
[JsonProperty("name")]
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs
index 4f050228..08d563c1 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs
@@ -1,31 +1,65 @@
-using System;
+using Newtonsoft.Json;
+using System;
using static Tensorflow.Binding;
namespace Tensorflow.Keras.ArgsDefinition
{
- public class ConvolutionalArgs : LayerArgs
+ public class ConvolutionalArgs : AutoSerializeLayerArgs
{
public int Rank { get; set; } = 2;
+ [JsonProperty("filters")]
public int Filters { get; set; }
public int NumSpatialDims { get; set; } = Unknown;
+ [JsonProperty("kernel_size")]
public Shape KernelSize { get; set; } = 5;
///
/// specifying the stride length of the convolution.
///
+ [JsonProperty("strides")]
public Shape Strides { get; set; } = (1, 1);
-
+ [JsonProperty("padding")]
public string Padding { get; set; } = "valid";
+ [JsonProperty("data_format")]
public string DataFormat { get; set; }
+ [JsonProperty("dilation_rate")]
public Shape DilationRate { get; set; } = (1, 1);
+ [JsonProperty("groups")]
public int Groups { get; set; } = 1;
public Activation Activation { get; set; }
+ private string _activationName;
+ [JsonProperty("activation")]
+ public string ActivationName
+ {
+ get
+ {
+ if (string.IsNullOrEmpty(_activationName))
+ {
+ return Activation.Method.Name;
+ }
+ else
+ {
+ return _activationName;
+ }
+ }
+ set
+ {
+ _activationName = value;
+ }
+ }
+ [JsonProperty("use_bias")]
public bool UseBias { get; set; }
+ [JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
+ [JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
+ [JsonProperty("kernel_regularizer")]
public IRegularizer KernelRegularizer { get; set; }
+ [JsonProperty("bias_regularizer")]
public IRegularizer BiasRegularizer { get; set; }
+ [JsonProperty("kernel_constraint")]
public Action KernelConstraint { get; set; }
+ [JsonProperty("bias_constraint")]
public Action BiasConstraint { get; set; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/EinsumDenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs
similarity index 65%
rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/EinsumDenseArgs.cs
rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs
index 3a8642ff..9817e9c6 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Attention/EinsumDenseArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs
@@ -1,9 +1,10 @@
+using Newtonsoft.Json;
using System;
using static Tensorflow.Binding;
-namespace Tensorflow.Keras.ArgsDefinition
+namespace Tensorflow.Keras.ArgsDefinition.Core
{
- public class EinsumDenseArgs : LayerArgs
+ public class EinsumDenseArgs : AutoSerializeLayerArgs
{
///
/// An equation describing the einsum to perform. This equation must
@@ -11,6 +12,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis
/// expression sequence.
///
+ [JsonProperty("equation")]
public string Equation { get; set; }
///
@@ -19,6 +21,7 @@ namespace Tensorflow.Keras.ArgsDefinition
/// None for any dimension that is unknown or can be inferred from the input
/// shape.
///
+ [JsonProperty("output_shape")]
public Shape OutputShape { get; set; }
///
@@ -26,41 +29,70 @@ namespace Tensorflow.Keras.ArgsDefinition
/// Each character in the `bias_axes` string should correspond to a character
/// in the output portion of the `equation` string.
///
+ [JsonProperty("bias_axes")]
public string BiasAxes { get; set; } = null;
///
/// Activation function to use.
///
public Activation Activation { get; set; }
+ private string _activationName;
+ [JsonProperty("activation")]
+ public string ActivationName
+ {
+ get
+ {
+ if (string.IsNullOrEmpty(_activationName))
+ {
+ return Activation.Method.Name;
+ }
+ else
+ {
+ return _activationName;
+ }
+ }
+ set
+ {
+ _activationName = value;
+ }
+ }
///
/// Initializer for the `kernel` weights matrix.
///
+ [JsonProperty("kernel_initializer")]
public IInitializer KernelInitializer { get; set; } = tf.glorot_uniform_initializer;
///
/// Initializer for the bias vector.
///
+ [JsonProperty("bias_initializer")]
public IInitializer BiasInitializer { get; set; } = tf.zeros_initializer;
///
/// Regularizer function applied to the `kernel` weights matrix.
///
+ [JsonProperty("kernel_regularizer")]
public IRegularizer KernelRegularizer { get; set; }
///
/// Regularizer function applied to the bias vector.
///
+ [JsonProperty("bias_regularizer")]
public IRegularizer BiasRegularizer { get; set; }
///
/// Constraint function applied to the `kernel` weights matrix.
///
+ [JsonProperty("kernel_constraint")]
public Action KernelConstraint { get; set; }
///
/// Constraint function applied to the bias vector.
///
+ [JsonProperty("bias_constraint")]
public Action BiasConstraint { get; set; }
+ [JsonProperty("activity_regularizer")]
+ public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs
index b1f4fddd..c462961b 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EmbeddingArgs.cs
@@ -1,11 +1,22 @@
-namespace Tensorflow.Keras.ArgsDefinition
+using Newtonsoft.Json;
+
+namespace Tensorflow.Keras.ArgsDefinition
{
- public class EmbeddingArgs : LayerArgs
+ public class EmbeddingArgs : AutoSerializeLayerArgs
{
+ [JsonProperty("input_dim")]
public int InputDim { get; set; }
+ [JsonProperty("output_dim")]
public int OutputDim { get; set; }
+ [JsonProperty("mask_zero")]
public bool MaskZero { get; set; }
+ [JsonProperty("input_length")]
public int InputLength { get; set; } = -1;
+ [JsonProperty("embeddings_initializer")]
public IInitializer EmbeddingsInitializer { get; set; }
+ [JsonProperty("activity_regularizer")]
+ public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; }
+
+ // TODO: `embeddings_regularizer`, `embeddings_constraint`.
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping2DArgs.cs
deleted file mode 100644
index 16705063..00000000
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping2DArgs.cs
+++ /dev/null
@@ -1,16 +0,0 @@
-using Tensorflow.NumPy;
-
-namespace Tensorflow.Keras.ArgsDefinition {
- public class Cropping2DArgs : LayerArgs {
- ///
- /// channel last: (b, h, w, c)
- /// channels_first: (b, c, h, w)
- ///
- public enum DataFormat { channels_first = 0, channels_last = 1 }
- ///
- /// Accept: int[1][2], int[1][1], int[2][2]
- ///
- public NDArray cropping { get; set; }
- public DataFormat data_format { get; set; } = DataFormat.channels_last;
- }
-}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping3DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping3DArgs.cs
deleted file mode 100644
index 9da2adc7..00000000
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/Cropping3DArgs.cs
+++ /dev/null
@@ -1,16 +0,0 @@
-using Tensorflow.NumPy;
-
-namespace Tensorflow.Keras.ArgsDefinition {
- public class Cropping3DArgs : LayerArgs {
- ///
- /// channel last: (b, h, w, c)
- /// channels_first: (b, c, h, w)
- ///
- public enum DataFormat { channels_first = 0, channels_last = 1 }
- ///
- /// Accept: int[1][3], int[1][1], int[3][2]
- ///
- public NDArray cropping { get; set; }
- public DataFormat data_format { get; set; } = DataFormat.channels_last;
- }
-}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/CroppingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/CroppingArgs.cs
deleted file mode 100644
index 9d23acd4..00000000
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Cropping/CroppingArgs.cs
+++ /dev/null
@@ -1,10 +0,0 @@
-using Tensorflow.NumPy;
-
-namespace Tensorflow.Keras.ArgsDefinition {
- public class CroppingArgs : LayerArgs {
- ///
- /// Accept length 1 or 2
- ///
- public NDArray cropping { get; set; }
- }
-}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs
deleted file mode 100644
index fb0868dc..00000000
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMCellArgs.cs
+++ /dev/null
@@ -1,6 +0,0 @@
-namespace Tensorflow.Keras.ArgsDefinition.Lstm
-{
- public class LSTMCellArgs : LayerArgs
- {
- }
-}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs
index 3e6791e3..0140b3dd 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Merging/MergeArgs.cs
@@ -4,6 +4,7 @@ using System.Text;
namespace Tensorflow.Keras.ArgsDefinition
{
+ // TODO: complete the implementation
public class MergeArgs : LayerArgs
{
public Tensors Inputs { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs
index 954ede57..6ee91e80 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/BatchNormalizationArgs.cs
@@ -1,21 +1,37 @@
-using static Tensorflow.Binding;
+using Newtonsoft.Json;
+using static Tensorflow.Binding;
namespace Tensorflow.Keras.ArgsDefinition
{
- public class BatchNormalizationArgs : LayerArgs
+ public class BatchNormalizationArgs : AutoSerializeLayerArgs
{
+ [JsonProperty("axis")]
public Shape Axis { get; set; } = -1;
+ [JsonProperty("momentum")]
public float Momentum { get; set; } = 0.99f;
+ [JsonProperty("epsilon")]
public float Epsilon { get; set; } = 1e-3f;
+ [JsonProperty("center")]
public bool Center { get; set; } = true;
+ [JsonProperty("scale")]
public bool Scale { get; set; } = true;
+ [JsonProperty("beta_initializer")]
public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer;
+ [JsonProperty("gamma_initializer")]
public IInitializer GammaInitializer { get; set; } = tf.ones_initializer;
+ [JsonProperty("moving_mean_initializer")]
public IInitializer MovingMeanInitializer { get; set; } = tf.zeros_initializer;
+ [JsonProperty("moving_variance_initializer")]
public IInitializer MovingVarianceInitializer { get; set; } = tf.ones_initializer;
+ [JsonProperty("beta_regularizer")]
public IRegularizer BetaRegularizer { get; set; }
+ [JsonProperty("gamma_regularizer")]
public IRegularizer GammaRegularizer { get; set; }
+ // TODO: `beta_constraint` and `gamma_constraint`.
+ [JsonProperty("renorm")]
public bool Renorm { get; set; }
+ // TODO: `renorm_clipping` and `virtual_batch_size`.
+ [JsonProperty("renorm_momentum")]
public float RenormMomentum { get; set; } = 0.99f;
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs
index 13fd98b4..1ac661b3 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Normalization/LayerNormalizationArgs.cs
@@ -1,16 +1,27 @@
-using static Tensorflow.Binding;
+using Newtonsoft.Json;
+using static Tensorflow.Binding;
namespace Tensorflow.Keras.ArgsDefinition
{
- public class LayerNormalizationArgs : LayerArgs
+ public class LayerNormalizationArgs : AutoSerializeLayerArgs
{
+ [JsonProperty("axis")]
public Axis Axis { get; set; } = -1;
+ [JsonProperty("epsilon")]
public float Epsilon { get; set; } = 1e-3f;
+ [JsonProperty("center")]
public bool Center { get; set; } = true;
+ [JsonProperty("scale")]
public bool Scale { get; set; } = true;
+ [JsonProperty("beta_initializer")]
public IInitializer BetaInitializer { get; set; } = tf.zeros_initializer;
+ [JsonProperty("gamma_initializer")]
public IInitializer GammaInitializer { get; set; } = tf.ones_initializer;
+ [JsonProperty("beta_regularizer")]
public IRegularizer BetaRegularizer { get; set; }
+ [JsonProperty("gamma_regularizer")]
public IRegularizer GammaRegularizer { get; set; }
+
+ // TODO: `beta_constraint` and `gamma_constraint`.
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs
index 9742203d..c5fdca67 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling1DArgs.cs
@@ -1,6 +1,8 @@
-namespace Tensorflow.Keras.ArgsDefinition
+using Newtonsoft.Json;
+
+namespace Tensorflow.Keras.ArgsDefinition
{
- public class Pooling1DArgs : LayerArgs
+ public class Pooling1DArgs : AutoSerializeLayerArgs
{
///
/// The pooling function to apply, e.g. `tf.nn.max_pool2d`.
@@ -10,11 +12,13 @@
///
/// specifying the size of the pooling window.
///
+ [JsonProperty("pool_size")]
public int PoolSize { get; set; }
///
/// specifying the strides of the pooling operation.
///
+ [JsonProperty("strides")]
public int Strides {
get { return _strides.HasValue ? _strides.Value : PoolSize; }
set { _strides = value; }
@@ -24,11 +28,13 @@
///
/// The padding method, either 'valid' or 'same'.
///
+ [JsonProperty("padding")]
public string Padding { get; set; } = "valid";
///
/// one of `channels_last` (default) or `channels_first`.
///
+ [JsonProperty("data_format")]
public string DataFormat { get; set; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs
index 1260af4c..91a372ef 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Pooling/Pooling2DArgs.cs
@@ -1,6 +1,8 @@
-namespace Tensorflow.Keras.ArgsDefinition
+using Newtonsoft.Json;
+
+namespace Tensorflow.Keras.ArgsDefinition
{
- public class Pooling2DArgs : LayerArgs
+ public class Pooling2DArgs : AutoSerializeLayerArgs
{
///
/// The pooling function to apply, e.g. `tf.nn.max_pool2d`.
@@ -10,21 +12,25 @@
///
/// specifying the size of the pooling window.
///
+ [JsonProperty("pool_size")]
public Shape PoolSize { get; set; }
///
/// specifying the strides of the pooling operation.
///
+ [JsonProperty("strides")]
public Shape Strides { get; set; }
///
/// The padding method, either 'valid' or 'same'.
///
+ [JsonProperty("padding")]
public string Padding { get; set; } = "valid";
///
/// one of `channels_last` (default) or `channels_first`.
///
+ [JsonProperty("data_format")]
public string DataFormat { get; set; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs
index 28ccf9f7..97cb364d 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/PreprocessingLayerArgs.cs
@@ -4,7 +4,7 @@ using System.Text;
namespace Tensorflow.Keras.ArgsDefinition
{
- public class PreprocessingLayerArgs : LayerArgs
+ public class PreprocessingLayerArgs : AutoSerializeLayerArgs
{
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs
new file mode 100644
index 00000000..154bd8c8
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/RescalingArgs.cs
@@ -0,0 +1,12 @@
+using Newtonsoft.Json;
+
+namespace Tensorflow.Keras.ArgsDefinition
+{
+ public class RescalingArgs : AutoSerializeLayerArgs
+ {
+ [JsonProperty("scale")]
+ public float Scale { get; set; }
+ [JsonProperty("offset")]
+ public float Offset { get; set; }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs
index cf11595e..39fa5221 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/ResizingArgs.cs
@@ -1,5 +1,6 @@
namespace Tensorflow.Keras.ArgsDefinition
{
+ // TODO: no corresponding class found in keras python, maybe obselete?
public class ResizingArgs : PreprocessingLayerArgs
{
public int Height { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs
index ddeadc00..1a7149f5 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Preprocessing/TextVectorizationArgs.cs
@@ -1,4 +1,5 @@
-using System;
+using Newtonsoft.Json;
+using System;
using System.Collections.Generic;
using System.Text;
@@ -6,11 +7,19 @@ namespace Tensorflow.Keras.ArgsDefinition
{
public class TextVectorizationArgs : PreprocessingLayerArgs
{
+ [JsonProperty("standardize")]
public Func Standardize { get; set; }
+ [JsonProperty("split")]
public string Split { get; set; } = "standardize";
+ [JsonProperty("max_tokens")]
public int MaxTokens { get; set; } = -1;
+ [JsonProperty("output_mode")]
public string OutputMode { get; set; } = "int";
+ [JsonProperty("output_sequence_length")]
public int OutputSequenceLength { get; set; } = -1;
+ [JsonProperty("vocabulary")]
public string[] Vocabulary { get; set; }
+
+ // TODO: Add `ngrams`, `sparse`, `ragged`, `idf_weights`, `encoding`
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs
index c41c6fe8..1c85d493 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Regularization/DropoutArgs.cs
@@ -1,21 +1,26 @@
-namespace Tensorflow.Keras.ArgsDefinition
+using Newtonsoft.Json;
+
+namespace Tensorflow.Keras.ArgsDefinition
{
- public class DropoutArgs : LayerArgs
+ public class DropoutArgs : AutoSerializeLayerArgs
{
///
/// Float between 0 and 1. Fraction of the input units to drop.
///
+ [JsonProperty("rate")]
public float Rate { get; set; }
///
/// 1D integer tensor representing the shape of the
/// binary dropout mask that will be multiplied with the input.
///
+ [JsonProperty("noise_shape")]
public Shape NoiseShape { get; set; }
///
/// random seed.
///
+ [JsonProperty("seed")]
public int? Seed { get; set; }
public bool SupportsMasking { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs
deleted file mode 100644
index ec9b5315..00000000
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rescaling/RescalingArgs.cs
+++ /dev/null
@@ -1,8 +0,0 @@
-namespace Tensorflow.Keras.ArgsDefinition
-{
- public class RescalingArgs : LayerArgs
- {
- public float Scale { get; set; }
- public float Offset { get; set; }
- }
-}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs
new file mode 100644
index 00000000..8c262639
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping2DArgs.cs
@@ -0,0 +1,18 @@
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Keras.ArgsDefinition.Reshaping
+{
+ public class Cropping2DArgs : LayerArgs
+ {
+ ///
+ /// channel last: (b, h, w, c)
+ /// channels_first: (b, c, h, w)
+ ///
+ public enum DataFormat { channels_first = 0, channels_last = 1 }
+ ///
+ /// Accept: int[1][2], int[1][1], int[2][2]
+ ///
+ public NDArray cropping { get; set; }
+ public DataFormat data_format { get; set; } = DataFormat.channels_last;
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs
new file mode 100644
index 00000000..2d98e55d
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/Cropping3DArgs.cs
@@ -0,0 +1,18 @@
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Keras.ArgsDefinition.Reshaping
+{
+ public class Cropping3DArgs : LayerArgs
+ {
+ ///
+ /// channel last: (b, h, w, c)
+ /// channels_first: (b, c, h, w)
+ ///
+ public enum DataFormat { channels_first = 0, channels_last = 1 }
+ ///
+ /// Accept: int[1][3], int[1][1], int[3][2]
+ ///
+ public NDArray cropping { get; set; }
+ public DataFormat data_format { get; set; } = DataFormat.channels_last;
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs
new file mode 100644
index 00000000..21b85966
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/CroppingArgs.cs
@@ -0,0 +1,12 @@
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Keras.ArgsDefinition.Reshaping
+{
+ public class Cropping1DArgs : LayerArgs
+ {
+ ///
+ /// Accept length 1 or 2
+ ///
+ public NDArray cropping { get; set; }
+ }
+}
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs
index 2686f6cd..92be10ab 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/PermuteArgs.cs
@@ -1,5 +1,9 @@
-namespace Tensorflow.Keras.ArgsDefinition {
- public class PermuteArgs : LayerArgs {
- public int[] dims { get; set; }
- }
+using Newtonsoft.Json;
+
+namespace Tensorflow.Keras.ArgsDefinition {
+ public class PermuteArgs : AutoSerializeLayerArgs
+ {
+ [JsonProperty("dims")]
+ public int[] dims { get; set; }
+ }
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs
index 77bca8ad..4d1123c8 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ReshapeArgs.cs
@@ -1,7 +1,10 @@
-namespace Tensorflow.Keras.ArgsDefinition
+using Newtonsoft.Json;
+
+namespace Tensorflow.Keras.ArgsDefinition
{
- public class ReshapeArgs : LayerArgs
+ public class ReshapeArgs : AutoSerializeLayerArgs
{
+ [JsonProperty("target_shape")]
public Shape TargetShape { get; set; }
public object[] TargetShapeObjects { get; set; }
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs
index 7fdda32d..b35e0e4b 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/UpSampling2DArgs.cs
@@ -1,12 +1,17 @@
-namespace Tensorflow.Keras.ArgsDefinition
+using Newtonsoft.Json;
+
+namespace Tensorflow.Keras.ArgsDefinition
{
- public class UpSampling2DArgs : LayerArgs
+ public class UpSampling2DArgs : AutoSerializeLayerArgs
{
+ [JsonProperty("size")]
public Shape Size { get; set; }
+ [JsonProperty("data_format")]
public string DataFormat { get; set; }
///
/// 'nearest', 'bilinear'
///
+ [JsonProperty("interpolation")]
public string Interpolation { get; set; } = "nearest";
}
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs
index ed6e7cc9..4831e435 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Reshaping/ZeroPadding2DArgs.cs
@@ -2,6 +2,7 @@
namespace Tensorflow.Keras.ArgsDefinition
{
+ // TODO: complete the implementation
public class ZeroPadding2DArgs : LayerArgs
{
public NDArray Padding { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
similarity index 67%
rename from src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs
rename to src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
index b08d21d8..76464147 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Lstm/LSTMArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMArgs.cs
@@ -1,9 +1,8 @@
-using Tensorflow.Keras.ArgsDefinition.Rnn;
-
-namespace Tensorflow.Keras.ArgsDefinition.Lstm
+namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
public class LSTMArgs : RNNArgs
{
+ // TODO: maybe change the `RNNArgs` and implement this class.
public bool UnitForgetBias { get; set; }
public float Dropout { get; set; }
public float RecurrentDropout { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
new file mode 100644
index 00000000..594c99bb
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs
@@ -0,0 +1,7 @@
+namespace Tensorflow.Keras.ArgsDefinition.Rnn
+{
+ // TODO: complete the implementation
+ public class LSTMCellArgs : LayerArgs
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
index da527925..2585592c 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs
@@ -1,21 +1,30 @@
-using System.Collections.Generic;
+using Newtonsoft.Json;
+using System.Collections.Generic;
namespace Tensorflow.Keras.ArgsDefinition.Rnn
{
- public class RNNArgs : LayerArgs
+ public class RNNArgs : AutoSerializeLayerArgs
{
public interface IRnnArgCell : ILayer
{
object state_size { get; }
}
-
+ [JsonProperty("cell")]
+ // TODO: the cell should be serialized with `serialize_keras_object`.
public IRnnArgCell Cell { get; set; } = null;
+ [JsonProperty("return_sequences")]
public bool ReturnSequences { get; set; } = false;
+ [JsonProperty("return_state")]
public bool ReturnState { get; set; } = false;
+ [JsonProperty("go_backwards")]
public bool GoBackwards { get; set; } = false;
+ [JsonProperty("stateful")]
public bool Stateful { get; set; } = false;
+ [JsonProperty("unroll")]
public bool Unroll { get; set; } = false;
+ [JsonProperty("time_major")]
public bool TimeMajor { get; set; } = false;
+ // TODO: Add `num_constants` and `zero_output_for_mask`.
public Dictionary Kwargs { get; set; } = null;
public int Units { get; set; }
diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs
index 602e7a88..3578652e 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.Cropping.cs
@@ -1,5 +1,5 @@
using System;
-using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.ArgsDefinition.Reshaping;
using Tensorflow.NumPy;
namespace Tensorflow.Keras.Layers
diff --git a/src/TensorFlowNET.Keras/Activations.cs b/src/TensorFlowNET.Keras/Activations.cs
new file mode 100644
index 00000000..444c783e
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Activations.cs
@@ -0,0 +1,82 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Operations.Activation;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Keras
+{
+ public class Activations
+ {
+ private static Dictionary _nameActivationMap;
+ private static Dictionary _activationNameMap;
+
+ private static Activation _linear = (features, name) => features;
+ private static Activation _relu = (features, name)
+ => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features));
+ private static Activation _sigmoid = (features, name)
+ => tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features));
+ private static Activation _softmax = (features, name)
+ => tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features));
+ private static Activation _tanh = (features, name)
+ => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features));
+
+ ///
+ /// Register the name-activation mapping in this static class.
+ ///
+ ///
+ ///
+ private static void RegisterActivation(string name, Activation activation)
+ {
+ _nameActivationMap[name] = activation;
+ _activationNameMap[activation] = name;
+ }
+
+ static Activations()
+ {
+ _nameActivationMap = new Dictionary();
+ _activationNameMap= new Dictionary();
+
+ RegisterActivation("relu", _relu);
+ RegisterActivation("linear", _linear);
+ RegisterActivation("sigmoid", _sigmoid);
+ RegisterActivation("softmax", _softmax);
+ RegisterActivation("tanh", _tanh);
+ }
+
+ public Activation Linear => _linear;
+
+ public Activation Relu => _relu;
+
+ public Activation Sigmoid => _sigmoid;
+
+ public Activation Softmax => _softmax;
+
+ public Activation Tanh => _tanh;
+
+
+ public static Activation GetActivationByName(string name)
+ {
+ if (!_nameActivationMap.TryGetValue(name, out var res))
+ {
+ throw new Exception($"Activation {name} not found");
+ }
+ else
+ {
+ return res;
+ }
+ }
+
+ public static string GetNameByActivation(Activation activation)
+ {
+ if(!_activationNameMap.TryGetValue(activation, out var name))
+ {
+ throw new Exception($"Activation {activation} not found");
+ }
+ else
+ {
+ return name;
+ }
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Linear.cs b/src/TensorFlowNET.Keras/Activations/Activations.Linear.cs
deleted file mode 100644
index acd4de6e..00000000
--- a/src/TensorFlowNET.Keras/Activations/Activations.Linear.cs
+++ /dev/null
@@ -1,10 +0,0 @@
-namespace Tensorflow.Keras
-{
- public partial class Activations
- {
- ///
- /// Linear activation function (pass-through).
- ///
- public Activation Linear = (features, name) => features;
- }
-}
diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs b/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs
deleted file mode 100644
index dfebfb29..00000000
--- a/src/TensorFlowNET.Keras/Activations/Activations.Relu.cs
+++ /dev/null
@@ -1,10 +0,0 @@
-using static Tensorflow.Binding;
-
-namespace Tensorflow.Keras
-{
- public partial class Activations
- {
- public Activation Relu = (features, name)
- => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features));
- }
-}
diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs b/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs
deleted file mode 100644
index ad900bde..00000000
--- a/src/TensorFlowNET.Keras/Activations/Activations.Sigmoid.cs
+++ /dev/null
@@ -1,11 +0,0 @@
-using System;
-using static Tensorflow.Binding;
-
-namespace Tensorflow.Keras
-{
- public partial class Activations
- {
- public Activation Sigmoid = (features, name)
- => tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features));
- }
-}
diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Softmax.cs b/src/TensorFlowNET.Keras/Activations/Activations.Softmax.cs
deleted file mode 100644
index 02d86ace..00000000
--- a/src/TensorFlowNET.Keras/Activations/Activations.Softmax.cs
+++ /dev/null
@@ -1,11 +0,0 @@
-using System;
-using static Tensorflow.Binding;
-
-namespace Tensorflow.Keras
-{
- public partial class Activations
- {
- public Activation Softmax = (features, name)
- => tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features));
- }
-}
diff --git a/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs b/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs
deleted file mode 100644
index 33dc5ba6..00000000
--- a/src/TensorFlowNET.Keras/Activations/Activations.Tanh.cs
+++ /dev/null
@@ -1,11 +0,0 @@
-using System;
-using static Tensorflow.Binding;
-
-namespace Tensorflow.Keras
-{
- public partial class Activations
- {
- public Activation Tanh = (features, name)
- => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features));
- }
-}
diff --git a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs
index 1b82e0a9..701724d5 100644
--- a/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs
+++ b/src/TensorFlowNET.Keras/Layers/Attention/MultiHeadAttention.cs
@@ -1,4 +1,5 @@
using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.ArgsDefinition.Core;
using Tensorflow.Keras.Engine;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
diff --git a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
index 0f387570..af71ddf9 100644
--- a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
+++ b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
@@ -4,8 +4,8 @@ using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
-using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
+using Tensorflow.Keras.ArgsDefinition.Core;
namespace Tensorflow.Keras.Layers
{
diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs
deleted file mode 100644
index 1f33ee3a..00000000
--- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs
+++ /dev/null
@@ -1,114 +0,0 @@
-using Tensorflow.Keras.ArgsDefinition;
-using Tensorflow.Keras.Engine;
-
-namespace Tensorflow.Keras.Layers {
- ///
- /// Crop the input along axis 1 and 2.
- /// For example:
- /// shape (1, 5, 5, 5) -- crop2D ((1, 2), (1, 3)) --> shape (1, 2, 1, 5)
- ///
- public class Cropping2D : Layer {
- Cropping2DArgs args;
- public Cropping2D ( Cropping2DArgs args ) : base(args) {
- this.args = args;
- }
- public override void build(Shape input_shape) {
- built = true;
- _buildInputShape = input_shape;
- }
- protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
- Tensor output = inputs;
- if ( output.rank != 4 ) {
- // throw an ValueError exception
- throw new ValueError("Expected dim=4, found dim=" + output.rank);
- }
- if ( args.cropping.shape == new Shape(1) ) {
- int crop = args.cropping[0];
- if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) {
- output = output[new Slice(),
- new Slice(crop, ( int ) output.shape[1] - crop),
- new Slice(crop, ( int ) output.shape[2] - crop),
- new Slice()];
- }
- else {
- output = output[new Slice(),
- new Slice(),
- new Slice(crop, ( int ) output.shape[2] - crop),
- new Slice(crop, ( int ) output.shape[3] - crop)];
- }
- }
- // a tuple of 2 integers
- else if ( args.cropping.shape == new Shape(2) ) {
- int crop_1 = args.cropping[0];
- int crop_2 = args.cropping[1];
- if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) {
- output = output[new Slice(),
- new Slice(crop_1, ( int ) output.shape[1] - crop_1),
- new Slice(crop_2, ( int ) output.shape[2] - crop_2),
- new Slice()];
- }
- else {
- output = output[new Slice(),
- new Slice(),
- new Slice(crop_1, ( int ) output.shape[2] - crop_1),
- new Slice(crop_2, ( int ) output.shape[3] - crop_2)];
- }
- }
- else if ( args.cropping.shape[0] == 2 && args.cropping.shape[1] == 2 ) {
- int x_start = args.cropping[0, 0], x_end = args.cropping[0, 1];
- int y_start = args.cropping[1, 0], y_end = args.cropping[1, 1];
- if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) {
- output = output[new Slice(),
- new Slice(x_start, ( int ) output.shape[1] - x_end),
- new Slice(y_start, ( int ) output.shape[2] - y_end),
- new Slice()];
- }
- else {
- output = output[new Slice(),
- new Slice(),
- new Slice(x_start, ( int ) output.shape[2] - x_end),
- new Slice(y_start, ( int ) output.shape[3] - y_end)
- ];
- }
- }
- return output;
- }
-
- public override Shape ComputeOutputShape ( Shape input_shape ) {
- if ( args.cropping.shape == new Shape(1) ) {
- int crop = args.cropping[0];
- if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) {
- return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop * 2, ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3]);
- }
- else {
- return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3] - crop * 2);
- }
- }
- // a tuple of 2 integers
- else if ( args.cropping.shape == new Shape(2) ) {
- int crop_1 = args.cropping[0], crop_2 = args.cropping[1];
- if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) {
- return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop_1 * 2, ( int ) input_shape[2] - crop_2 * 2, ( int ) input_shape[3]);
- }
- else {
- return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop_1 * 2, ( int ) input_shape[3] - crop_2 * 2);
- }
- }
- else if ( args.cropping.shape == new Shape(2, 2) ) {
- int crop_1_start = args.cropping[0, 0], crop_1_end = args.cropping[0, 1];
- int crop_2_start = args.cropping[1, 0], crop_2_end = args.cropping[1, 1];
- if ( args.data_format == Cropping2DArgs.DataFormat.channels_last ) {
- return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop_1_start - crop_1_end,
- ( int ) input_shape[2] - crop_2_start - crop_2_end, ( int ) input_shape[3]);
- }
- else {
- return new Shape(( int ) input_shape[0], ( int ) input_shape[1],
- ( int ) input_shape[2] - crop_1_start - crop_1_end, ( int ) input_shape[3] - crop_2_start - crop_2_end);
- }
- }
- else {
- throw new ValueError();
- }
- }
- }
-}
diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs
deleted file mode 100644
index 838a5043..00000000
--- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs
+++ /dev/null
@@ -1,124 +0,0 @@
-using Tensorflow.Keras.ArgsDefinition;
-using Tensorflow.Keras.Engine;
-
-namespace Tensorflow.Keras.Layers {
- ///
- /// Similar to copping 2D
- ///
- public class Cropping3D : Layer {
- Cropping3DArgs args;
- public Cropping3D ( Cropping3DArgs args ) : base(args) {
- this.args = args;
- }
-
- public override void build(Shape input_shape) {
- built = true;
- _buildInputShape = input_shape;
- }
-
- protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
- Tensor output = inputs;
- if ( output.rank != 5 ) {
- // throw an ValueError exception
- throw new ValueError("Expected dim=5, found dim=" + output.rank);
- }
-
- if ( args.cropping.shape == new Shape(1) ) {
- int crop = args.cropping[0];
- if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) {
- output = output[new Slice(),
- new Slice(crop, ( int ) output.shape[1] - crop),
- new Slice(crop, ( int ) output.shape[2] - crop),
- new Slice(crop, ( int ) output.shape[3] - crop),
- new Slice()];
- }
- else {
- output = output[new Slice(),
- new Slice(),
- new Slice(crop, ( int ) output.shape[2] - crop),
- new Slice(crop, ( int ) output.shape[3] - crop),
- new Slice(crop, ( int ) output.shape[4] - crop)];
- }
-
- }
- // int[1][3] equivalent to a tuple of 3 integers
- else if ( args.cropping.shape == new Shape(3) ) {
- var crop_1 = args.cropping[0];
- var crop_2 = args.cropping[1];
- var crop_3 = args.cropping[2];
- if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) {
- output = output[new Slice(),
- new Slice(crop_1, ( int ) output.shape[1] - crop_1),
- new Slice(crop_2, ( int ) output.shape[2] - crop_2),
- new Slice(crop_3, ( int ) output.shape[3] - crop_3),
- new Slice()];
- }
- else {
- output = output[new Slice(),
- new Slice(),
- new Slice(crop_1, ( int ) output.shape[2] - crop_1),
- new Slice(crop_2, ( int ) output.shape[3] - crop_2),
- new Slice(crop_3, ( int ) output.shape[4] - crop_3)];
- }
- }
- else if ( args.cropping.shape[0] == 3 && args.cropping.shape[1] == 2 ) {
- int x = args.cropping[0, 0], x_end = args.cropping[0, 1];
- int y = args.cropping[1, 0], y_end = args.cropping[1, 1];
- int z = args.cropping[2, 0], z_end = args.cropping[2, 1];
- if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) {
- output = output[new Slice(),
- new Slice(x, ( int ) output.shape[1] - x_end),
- new Slice(y, ( int ) output.shape[2] - y_end),
- new Slice(z, ( int ) output.shape[3] - z_end),
- new Slice()];
- }
- else {
- output = output[new Slice(),
- new Slice(),
- new Slice(x, ( int ) output.shape[2] - x_end),
- new Slice(y, ( int ) output.shape[3] - y_end),
- new Slice(z, ( int ) output.shape[4] - z_end)
- ];
- }
- }
- return output;
- }
- public override Shape ComputeOutputShape ( Shape input_shape ) {
- if ( args.cropping.shape == new Shape(1) ) {
- int crop = args.cropping[0];
- if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) {
- return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop * 2, ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3] - crop * 2, ( int ) input_shape[4]);
- }
- else {
- return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop * 2, ( int ) input_shape[3] - crop * 2, ( int ) input_shape[4] - crop * 2);
- }
- }
- // int[1][3] equivalent to a tuple of 3 integers
- else if ( args.cropping.shape == new Shape(3) ) {
- var crop_start_1 = args.cropping[0];
- var crop_start_2 = args.cropping[1];
- var crop_start_3 = args.cropping[2];
- if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) {
- return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - crop_start_1 * 2, ( int ) input_shape[2] - crop_start_2 * 2, ( int ) input_shape[3] - crop_start_3 * 2, ( int ) input_shape[4]);
- }
- else {
- return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - crop_start_1 * 2, ( int ) input_shape[3] - crop_start_2 * 2, ( int ) input_shape[4] - crop_start_3 * 2);
- }
- }
- else if ( args.cropping.shape == new Shape(3, 2) ) {
- int x = args.cropping[0, 0], x_end = args.cropping[0, 1];
- int y = args.cropping[1, 0], y_end = args.cropping[1, 1];
- int z = args.cropping[2, 0], z_end = args.cropping[2, 1];
- if ( args.data_format == Cropping3DArgs.DataFormat.channels_last ) {
- return new Shape(( int ) input_shape[0], ( int ) input_shape[1] - x - x_end, ( int ) input_shape[2] - y - y_end, ( int ) input_shape[3] - z - z_end, ( int ) input_shape[4]);
- }
- else {
- return new Shape(( int ) input_shape[0], ( int ) input_shape[1], ( int ) input_shape[2] - x - x_end, ( int ) input_shape[3] - y - y_end, ( int ) input_shape[4] - z - z_end);
- }
- }
- else {
- throw new ValueError();
- }
- }
- }
-}
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs
index 339ddb85..3e3442f2 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.Cropping.cs
@@ -2,16 +2,18 @@
using System;
using System.Collections.Generic;
using System.Text;
-using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.Layers.Reshaping;
+using Tensorflow.Keras.ArgsDefinition.Reshaping;
-namespace Tensorflow.Keras.Layers {
- public partial class LayersApi {
+namespace Tensorflow.Keras.Layers
+{
+ public partial class LayersApi {
///
/// Cropping layer for 1D input
///
/// cropping size
public ILayer Cropping1D ( NDArray cropping )
- => new Cropping1D(new CroppingArgs {
+ => new Cropping1D(new Cropping1DArgs {
cropping = cropping
});
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 769beea0..76634918 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -1,9 +1,8 @@
using System;
using Tensorflow.Keras.ArgsDefinition;
-using Tensorflow.Keras.ArgsDefinition.Lstm;
+using Tensorflow.Keras.ArgsDefinition.Core;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
-using Tensorflow.Keras.Layers.Lstm;
using Tensorflow.Keras.Layers.Rnn;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
@@ -108,7 +107,7 @@ namespace Tensorflow.Keras.Layers
DilationRate = dilation_rate,
Groups = groups,
UseBias = use_bias,
- Activation = GetActivationByName(activation),
+ Activation = Activations.GetActivationByName(activation),
KernelInitializer = GetInitializerByName(kernel_initializer),
BiasInitializer = GetInitializerByName(bias_initializer)
});
@@ -163,7 +162,7 @@ namespace Tensorflow.Keras.Layers
BiasInitializer = bias_initializer == null ? tf.zeros_initializer : bias_initializer,
BiasRegularizer = bias_regularizer,
ActivityRegularizer = activity_regularizer,
- Activation = activation ?? keras.activations.Linear
+ Activation = activation ?? keras.activations.Linear,
});
///
@@ -210,7 +209,8 @@ namespace Tensorflow.Keras.Layers
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
- Activation = GetActivationByName(activation)
+ Activation = Activations.GetActivationByName(activation),
+ ActivationName = activation
});
///
@@ -255,7 +255,7 @@ namespace Tensorflow.Keras.Layers
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
- Activation = GetActivationByName(activation)
+ Activation = Activations.GetActivationByName(activation)
});
///
@@ -300,7 +300,7 @@ namespace Tensorflow.Keras.Layers
=> new Dense(new DenseArgs
{
Units = units,
- Activation = GetActivationByName("linear"),
+ Activation = Activations.GetActivationByName("linear"),
ActivationName = "linear"
});
@@ -321,7 +321,7 @@ namespace Tensorflow.Keras.Layers
=> new Dense(new DenseArgs
{
Units = units,
- Activation = GetActivationByName(activation),
+ Activation = Activations.GetActivationByName(activation),
ActivationName = activation,
InputShape = input_shape
});
@@ -666,7 +666,7 @@ namespace Tensorflow.Keras.Layers
=> new SimpleRNN(new SimpleRNNArgs
{
Units = units,
- Activation = GetActivationByName(activation),
+ Activation = Activations.GetActivationByName(activation),
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
@@ -814,24 +814,7 @@ namespace Tensorflow.Keras.Layers
public ILayer GlobalMaxPooling2D(string data_format = "channels_last")
=> new GlobalMaxPooling2D(new Pooling2DArgs { DataFormat = data_format });
-
- ///
- /// Get an activation function layer from its name.
- ///
- /// The name of the activation function. One of linear, relu, sigmoid, and tanh.
- ///
-
- Activation GetActivationByName(string name)
- => name switch
- {
- "linear" => keras.activations.Linear,
- "relu" => keras.activations.Relu,
- "sigmoid" => keras.activations.Sigmoid,
- "tanh" => keras.activations.Tanh,
- "softmax" => keras.activations.Softmax,
- _ => throw new Exception($"Activation {name} not found")
- };
-
+ Activation GetActivationByName(string name) => Activations.GetActivationByName(name);
///
/// Get an weights initializer from its name.
///
diff --git a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
index c0b16c81..3b8e1ee8 100644
--- a/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
+++ b/src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs
@@ -58,7 +58,7 @@ namespace Tensorflow.Keras.Layers
var ndims = input_shape.ndim;
foreach (var (idx, x) in enumerate(axis))
if (x < 0)
- axis[idx] = ndims + x;
+ args.Axis.dims[idx] = axis[idx] = ndims + x;
fused = ndims == 4;
diff --git a/src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs b/src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs
similarity index 100%
rename from src/TensorFlowNET.Keras/Layers/Rescaling/Rescaling.cs
rename to src/TensorFlowNET.Keras/Layers/Preprocessing/Rescaling.cs
diff --git a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs
similarity index 79%
rename from src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs
rename to src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs
index 44b338c2..10c15b69 100644
--- a/src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs
+++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping1D.cs
@@ -1,11 +1,12 @@
-using Tensorflow.Keras.ArgsDefinition;
+using Tensorflow.Keras.ArgsDefinition.Reshaping;
using Tensorflow.Keras.Engine;
-namespace Tensorflow.Keras.Layers {
+namespace Tensorflow.Keras.Layers.Reshaping
+{
public class Cropping1D : Layer
{
- CroppingArgs args;
- public Cropping1D(CroppingArgs args) : base(args)
+ Cropping1DArgs args;
+ public Cropping1D(Cropping1DArgs args) : base(args)
{
this.args = args;
}
@@ -41,7 +42,7 @@ namespace Tensorflow.Keras.Layers {
else
{
int crop_start = args.cropping[0], crop_end = args.cropping[1];
- output = output[new Slice(), new Slice(crop_start, (int)(output.shape[1]) - crop_end), new Slice()];
+ output = output[new Slice(), new Slice(crop_start, (int)output.shape[1] - crop_end), new Slice()];
}
return output;
}
@@ -51,12 +52,12 @@ namespace Tensorflow.Keras.Layers {
if (args.cropping.shape[0] == 1)
{
int crop = args.cropping[0];
- return new Shape((int)(input_shape[0]), (int)(input_shape[1] - crop * 2), (int)(input_shape[2]));
+ return new Shape((int)input_shape[0], (int)(input_shape[1] - crop * 2), (int)input_shape[2]);
}
else
{
int crop_start = args.cropping[0], crop_end = args.cropping[1];
- return new Shape((int)(input_shape[0]), (int)(input_shape[1] - crop_start - crop_end), (int)(input_shape[2]));
+ return new Shape((int)input_shape[0], (int)(input_shape[1] - crop_start - crop_end), (int)input_shape[2]);
}
}
}
diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs
new file mode 100644
index 00000000..a8d7043e
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping2D.cs
@@ -0,0 +1,140 @@
+using Tensorflow.Keras.ArgsDefinition.Reshaping;
+using Tensorflow.Keras.Engine;
+
+namespace Tensorflow.Keras.Layers.Reshaping
+{
+ ///
+ /// Crop the input along axis 1 and 2.
+ /// For example:
+ /// shape (1, 5, 5, 5) -- crop2D ((1, 2), (1, 3)) --> shape (1, 2, 1, 5)
+ ///
+ public class Cropping2D : Layer
+ {
+ Cropping2DArgs args;
+ public Cropping2D(Cropping2DArgs args) : base(args)
+ {
+ this.args = args;
+ }
+ public override void build(Shape input_shape)
+ {
+ built = true;
+ _buildInputShape = input_shape;
+ }
+ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+ {
+ Tensor output = inputs;
+ if (output.rank != 4)
+ {
+ // throw an ValueError exception
+ throw new ValueError("Expected dim=4, found dim=" + output.rank);
+ }
+ if (args.cropping.shape == new Shape(1))
+ {
+ int crop = args.cropping[0];
+ if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
+ {
+ output = output[new Slice(),
+ new Slice(crop, (int)output.shape[1] - crop),
+ new Slice(crop, (int)output.shape[2] - crop),
+ new Slice()];
+ }
+ else
+ {
+ output = output[new Slice(),
+ new Slice(),
+ new Slice(crop, (int)output.shape[2] - crop),
+ new Slice(crop, (int)output.shape[3] - crop)];
+ }
+ }
+ // a tuple of 2 integers
+ else if (args.cropping.shape == new Shape(2))
+ {
+ int crop_1 = args.cropping[0];
+ int crop_2 = args.cropping[1];
+ if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
+ {
+ output = output[new Slice(),
+ new Slice(crop_1, (int)output.shape[1] - crop_1),
+ new Slice(crop_2, (int)output.shape[2] - crop_2),
+ new Slice()];
+ }
+ else
+ {
+ output = output[new Slice(),
+ new Slice(),
+ new Slice(crop_1, (int)output.shape[2] - crop_1),
+ new Slice(crop_2, (int)output.shape[3] - crop_2)];
+ }
+ }
+ else if (args.cropping.shape[0] == 2 && args.cropping.shape[1] == 2)
+ {
+ int x_start = args.cropping[0, 0], x_end = args.cropping[0, 1];
+ int y_start = args.cropping[1, 0], y_end = args.cropping[1, 1];
+ if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
+ {
+ output = output[new Slice(),
+ new Slice(x_start, (int)output.shape[1] - x_end),
+ new Slice(y_start, (int)output.shape[2] - y_end),
+ new Slice()];
+ }
+ else
+ {
+ output = output[new Slice(),
+ new Slice(),
+ new Slice(x_start, (int)output.shape[2] - x_end),
+ new Slice(y_start, (int)output.shape[3] - y_end)
+ ];
+ }
+ }
+ return output;
+ }
+
+ public override Shape ComputeOutputShape(Shape input_shape)
+ {
+ if (args.cropping.shape == new Shape(1))
+ {
+ int crop = args.cropping[0];
+ if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
+ {
+ return new Shape((int)input_shape[0], (int)input_shape[1] - crop * 2, (int)input_shape[2] - crop * 2, (int)input_shape[3]);
+ }
+ else
+ {
+ return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2);
+ }
+ }
+ // a tuple of 2 integers
+ else if (args.cropping.shape == new Shape(2))
+ {
+ int crop_1 = args.cropping[0], crop_2 = args.cropping[1];
+ if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
+ {
+ return new Shape((int)input_shape[0], (int)input_shape[1] - crop_1 * 2, (int)input_shape[2] - crop_2 * 2, (int)input_shape[3]);
+ }
+ else
+ {
+ return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop_1 * 2, (int)input_shape[3] - crop_2 * 2);
+ }
+ }
+ else if (args.cropping.shape == new Shape(2, 2))
+ {
+ int crop_1_start = args.cropping[0, 0], crop_1_end = args.cropping[0, 1];
+ int crop_2_start = args.cropping[1, 0], crop_2_end = args.cropping[1, 1];
+ if (args.data_format == Cropping2DArgs.DataFormat.channels_last)
+ {
+ return new Shape((int)input_shape[0], (int)input_shape[1] - crop_1_start - crop_1_end,
+ (int)input_shape[2] - crop_2_start - crop_2_end, (int)input_shape[3]);
+ }
+ else
+ {
+ return new Shape((int)input_shape[0], (int)input_shape[1],
+ (int)input_shape[2] - crop_1_start - crop_1_end, (int)input_shape[3] - crop_2_start - crop_2_end);
+ }
+ }
+ else
+ {
+ throw new ValueError();
+ }
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs
new file mode 100644
index 00000000..796c2dd3
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Layers/Reshaping/Cropping3D.cs
@@ -0,0 +1,150 @@
+using Tensorflow.Keras.ArgsDefinition.Reshaping;
+using Tensorflow.Keras.Engine;
+
+namespace Tensorflow.Keras.Layers.Reshaping
+{
+ ///
+ /// Similar to copping 2D
+ ///
+ public class Cropping3D : Layer
+ {
+ Cropping3DArgs args;
+ public Cropping3D(Cropping3DArgs args) : base(args)
+ {
+ this.args = args;
+ }
+
+ public override void build(Shape input_shape)
+ {
+ built = true;
+ _buildInputShape = input_shape;
+ }
+
+ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
+ {
+ Tensor output = inputs;
+ if (output.rank != 5)
+ {
+ // throw an ValueError exception
+ throw new ValueError("Expected dim=5, found dim=" + output.rank);
+ }
+
+ if (args.cropping.shape == new Shape(1))
+ {
+ int crop = args.cropping[0];
+ if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
+ {
+ output = output[new Slice(),
+ new Slice(crop, (int)output.shape[1] - crop),
+ new Slice(crop, (int)output.shape[2] - crop),
+ new Slice(crop, (int)output.shape[3] - crop),
+ new Slice()];
+ }
+ else
+ {
+ output = output[new Slice(),
+ new Slice(),
+ new Slice(crop, (int)output.shape[2] - crop),
+ new Slice(crop, (int)output.shape[3] - crop),
+ new Slice(crop, (int)output.shape[4] - crop)];
+ }
+
+ }
+ // int[1][3] equivalent to a tuple of 3 integers
+ else if (args.cropping.shape == new Shape(3))
+ {
+ var crop_1 = args.cropping[0];
+ var crop_2 = args.cropping[1];
+ var crop_3 = args.cropping[2];
+ if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
+ {
+ output = output[new Slice(),
+ new Slice(crop_1, (int)output.shape[1] - crop_1),
+ new Slice(crop_2, (int)output.shape[2] - crop_2),
+ new Slice(crop_3, (int)output.shape[3] - crop_3),
+ new Slice()];
+ }
+ else
+ {
+ output = output[new Slice(),
+ new Slice(),
+ new Slice(crop_1, (int)output.shape[2] - crop_1),
+ new Slice(crop_2, (int)output.shape[3] - crop_2),
+ new Slice(crop_3, (int)output.shape[4] - crop_3)];
+ }
+ }
+ else if (args.cropping.shape[0] == 3 && args.cropping.shape[1] == 2)
+ {
+ int x = args.cropping[0, 0], x_end = args.cropping[0, 1];
+ int y = args.cropping[1, 0], y_end = args.cropping[1, 1];
+ int z = args.cropping[2, 0], z_end = args.cropping[2, 1];
+ if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
+ {
+ output = output[new Slice(),
+ new Slice(x, (int)output.shape[1] - x_end),
+ new Slice(y, (int)output.shape[2] - y_end),
+ new Slice(z, (int)output.shape[3] - z_end),
+ new Slice()];
+ }
+ else
+ {
+ output = output[new Slice(),
+ new Slice(),
+ new Slice(x, (int)output.shape[2] - x_end),
+ new Slice(y, (int)output.shape[3] - y_end),
+ new Slice(z, (int)output.shape[4] - z_end)
+ ];
+ }
+ }
+ return output;
+ }
+ public override Shape ComputeOutputShape(Shape input_shape)
+ {
+ if (args.cropping.shape == new Shape(1))
+ {
+ int crop = args.cropping[0];
+ if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
+ {
+ return new Shape((int)input_shape[0], (int)input_shape[1] - crop * 2, (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2, (int)input_shape[4]);
+ }
+ else
+ {
+ return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop * 2, (int)input_shape[3] - crop * 2, (int)input_shape[4] - crop * 2);
+ }
+ }
+ // int[1][3] equivalent to a tuple of 3 integers
+ else if (args.cropping.shape == new Shape(3))
+ {
+ var crop_start_1 = args.cropping[0];
+ var crop_start_2 = args.cropping[1];
+ var crop_start_3 = args.cropping[2];
+ if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
+ {
+ return new Shape((int)input_shape[0], (int)input_shape[1] - crop_start_1 * 2, (int)input_shape[2] - crop_start_2 * 2, (int)input_shape[3] - crop_start_3 * 2, (int)input_shape[4]);
+ }
+ else
+ {
+ return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - crop_start_1 * 2, (int)input_shape[3] - crop_start_2 * 2, (int)input_shape[4] - crop_start_3 * 2);
+ }
+ }
+ else if (args.cropping.shape == new Shape(3, 2))
+ {
+ int x = args.cropping[0, 0], x_end = args.cropping[0, 1];
+ int y = args.cropping[1, 0], y_end = args.cropping[1, 1];
+ int z = args.cropping[2, 0], z_end = args.cropping[2, 1];
+ if (args.data_format == Cropping3DArgs.DataFormat.channels_last)
+ {
+ return new Shape((int)input_shape[0], (int)input_shape[1] - x - x_end, (int)input_shape[2] - y - y_end, (int)input_shape[3] - z - z_end, (int)input_shape[4]);
+ }
+ else
+ {
+ return new Shape((int)input_shape[0], (int)input_shape[1], (int)input_shape[2] - x - x_end, (int)input_shape[3] - y - y_end, (int)input_shape[4] - z - z_end);
+ }
+ }
+ else
+ {
+ throw new ValueError();
+ }
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs
similarity index 87%
rename from src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs
rename to src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs
index b7d97384..59555e62 100644
--- a/src/TensorFlowNET.Keras/Layers/Lstm/LSTM.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTM.cs
@@ -1,9 +1,8 @@
using System.Linq;
-using Tensorflow.Keras.ArgsDefinition.Lstm;
+using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
-using Tensorflow.Keras.Layers.Rnn;
-namespace Tensorflow.Keras.Layers.Lstm
+namespace Tensorflow.Keras.Layers.Rnn
{
///
/// Long Short-Term Memory layer - Hochreiter 1997.
diff --git a/src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs
similarity index 72%
rename from src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs
rename to src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs
index 3cd35a09..a622c91a 100644
--- a/src/TensorFlowNET.Keras/Layers/Lstm/LSTMCell.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs
@@ -1,7 +1,7 @@
-using Tensorflow.Keras.ArgsDefinition.Lstm;
+using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
-namespace Tensorflow.Keras.Layers.Lstm
+namespace Tensorflow.Keras.Layers.Rnn
{
public class LSTMCell : Layer
{
diff --git a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
index 877c3599..6b755ece 100644
--- a/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
+++ b/src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs
@@ -3,7 +3,6 @@ using System.Collections.Generic;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Rnn;
using Tensorflow.Keras.Engine;
-using Tensorflow.Keras.Layers.Lstm;
// from tensorflow.python.distribute import distribution_strategy_context as ds_context;
namespace Tensorflow.Keras.Layers.Rnn
diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs
deleted file mode 100644
index 288a92b3..00000000
--- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs
+++ /dev/null
@@ -1,82 +0,0 @@
-using Microsoft.VisualStudio.TestTools.UnitTesting;
-using Tensorflow.NumPy;
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
-using Tensorflow;
-using static Tensorflow.Binding;
-using static Tensorflow.KerasApi;
-using Tensorflow.Keras;
-using Tensorflow.Keras.ArgsDefinition;
-using Tensorflow.Keras.Engine;
-using Tensorflow.Keras.Layers;
-using Tensorflow.Keras.Losses;
-using Tensorflow.Keras.Metrics;
-using Tensorflow.Keras.Optimizers;
-using Tensorflow.Operations;
-
-namespace TensorFlowNET.Keras.UnitTest.SaveModel;
-
-[TestClass]
-public class SequentialModelTest
-{
- [TestMethod]
- public void SimpleModelFromAutoCompile()
- {
- var inputs = new KerasInterface().Input((28, 28, 1));
- var x = new Flatten(new FlattenArgs()).Apply(inputs);
- x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x);
- x = new LayersApi().Dense(units: 10).Apply(x);
- var outputs = new LayersApi().Softmax(axis: 1).Apply(x);
- var model = new KerasInterface().Model(inputs, outputs);
-
- model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });
-
- var data_loader = new MnistModelLoader();
- var num_epochs = 1;
- var batch_size = 50;
-
- var dataset = data_loader.LoadAsync(new ModelLoadSetting
- {
- TrainDir = "mnist",
- OneHot = false,
- ValidationSize = 10000,
- }).Result;
-
- model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
-
- model.save("C:\\Work\\tf.net\\tf_test\\tf.net.simple.compile", save_format: "tf");
- }
-
- [TestMethod]
- public void SimpleModelFromSequential()
- {
- Model model = KerasApi.keras.Sequential(new List()
- {
- keras.layers.InputLayer((28, 28, 1)),
- keras.layers.Flatten(),
- keras.layers.Dense(100, "relu"),
- keras.layers.Dense(10),
- keras.layers.Softmax(1)
- });
-
- model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });
-
- var data_loader = new MnistModelLoader();
- var num_epochs = 1;
- var batch_size = 50;
-
- var dataset = data_loader.LoadAsync(new ModelLoadSetting
- {
- TrainDir = "mnist",
- OneHot = false,
- ValidationSize = 10000,
- }).Result;
-
- model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
-
- model.save("C:\\Work\\tf.net\\tf_test\\tf.net.simple.sequential", save_format: "tf");
- }
-}
\ No newline at end of file