From 49fa114f2ad987722273133d8e13d43af99bea0d Mon Sep 17 00:00:00 2001 From: Ahmed Elsayed Date: Tue, 21 Dec 2021 18:20:05 +0200 Subject: [PATCH] Initial regularizers. --- .../Keras/ArgsDefinition/Core/DenseArgs.cs | 4 ++-- .../Keras/Regularizers/RegularizerArgs.cs | 7 ++++++ src/TensorFlowNET.Keras/Engine/Layer.cs | 11 +++++---- src/TensorFlowNET.Keras/Regularizers/L1.cs | 19 +++++++++++++++ src/TensorFlowNET.Keras/Regularizers/L1L2.cs | 24 +++++++++++++++++++ src/TensorFlowNET.Keras/Regularizers/L2.cs | 6 ++--- 6 files changed, 61 insertions(+), 10 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Regularizers/L1.cs create mode 100644 src/TensorFlowNET.Keras/Regularizers/L1L2.cs diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs index 7ff89c94..e9b3c2fd 100644 --- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs +++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs @@ -33,12 +33,12 @@ namespace Tensorflow.Keras.ArgsDefinition /// /// Regularizer function applied to the `kernel` weights matrix. /// - public IInitializer KernelRegularizer { get; set; } + public IRegularizer KernelRegularizer { get; set; } /// /// Regularizer function applied to the bias vector. /// - public IInitializer BiasRegularizer { get; set; } + public IRegularizer BiasRegularizer { get; set; } /// /// Constraint function applied to the `kernel` weights matrix. diff --git a/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs b/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs index 90100fe0..8e7e89b1 100644 --- a/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs +++ b/src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs @@ -2,5 +2,12 @@ { public class RegularizerArgs { + public Tensor X { get; set; } + + + public RegularizerArgs(Tensor x) + { + X = x; + } } } diff --git a/src/TensorFlowNET.Keras/Engine/Layer.cs b/src/TensorFlowNET.Keras/Engine/Layer.cs index e9d58b6f..75bb8f12 100644 --- a/src/TensorFlowNET.Keras/Engine/Layer.cs +++ b/src/TensorFlowNET.Keras/Engine/Layer.cs @@ -203,7 +203,7 @@ namespace Tensorflow.Keras.Engine protected virtual void add_loss(Func losses) { - + } /// @@ -214,10 +214,13 @@ namespace Tensorflow.Keras.Engine /// void _handle_weight_regularization(string name, IVariableV1 variable, IRegularizer regularizer) { - add_loss(() => regularizer.Apply(new RegularizerArgs - { - })); + add_loss(() => tf_with(ops.name_scope(name + "/Regularizer"), scope => + regularizer.Apply(new RegularizerArgs(variable.AsTensor()) + { + + }) + )); } /*protected virtual void add_update(Tensor[] updates, bool inputs = false) diff --git a/src/TensorFlowNET.Keras/Regularizers/L1.cs b/src/TensorFlowNET.Keras/Regularizers/L1.cs new file mode 100644 index 00000000..0f904b6f --- /dev/null +++ b/src/TensorFlowNET.Keras/Regularizers/L1.cs @@ -0,0 +1,19 @@ +using System; + +namespace Tensorflow.Keras +{ + public class L1 : IRegularizer + { + float l1; + + public L1(float l1 = 0.01f) + { + this.l1 = l1; + } + + public Tensor Apply(RegularizerArgs args) + { + return l1 * math_ops.reduce_sum(math_ops.abs(args.X)); + } + } +} diff --git a/src/TensorFlowNET.Keras/Regularizers/L1L2.cs b/src/TensorFlowNET.Keras/Regularizers/L1L2.cs new file mode 100644 index 00000000..f619f158 --- /dev/null +++ b/src/TensorFlowNET.Keras/Regularizers/L1L2.cs @@ -0,0 +1,24 @@ +using System; +using static Tensorflow.Binding; +namespace Tensorflow.Keras +{ + public class L1L2 : IRegularizer + { + float l1; + float l2; + + public L1L2(float l1 = 0.0f, float l2 = 0.0f) + { + this.l1 = l1; + this.l2 = l2; + + } + public Tensor Apply(RegularizerArgs args) + { + Tensor regularization = tf.constant(0.0, args.X.dtype); + regularization += l1 * math_ops.reduce_sum(math_ops.abs(args.X)); + regularization += l2 * math_ops.reduce_sum(math_ops.square(args.X)); + return regularization; + } + } +} diff --git a/src/TensorFlowNET.Keras/Regularizers/L2.cs b/src/TensorFlowNET.Keras/Regularizers/L2.cs index 9e293e89..034bbd23 100644 --- a/src/TensorFlowNET.Keras/Regularizers/L2.cs +++ b/src/TensorFlowNET.Keras/Regularizers/L2.cs @@ -1,6 +1,4 @@ -using System; - -namespace Tensorflow.Keras +namespace Tensorflow.Keras { public class L2 : IRegularizer { @@ -13,7 +11,7 @@ namespace Tensorflow.Keras public Tensor Apply(RegularizerArgs args) { - throw new NotImplementedException(); + return l2 * math_ops.reduce_sum(math_ops.square(args.X)); } } }