Browse Source

Initial regularizers.

pull/893/head
Ahmed Elsayed 3 years ago
parent
commit
49fa114f2a
6 changed files with 61 additions and 10 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs
  2. +7
    -0
      src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs
  3. +7
    -4
      src/TensorFlowNET.Keras/Engine/Layer.cs
  4. +19
    -0
      src/TensorFlowNET.Keras/Regularizers/L1.cs
  5. +24
    -0
      src/TensorFlowNET.Keras/Regularizers/L1L2.cs
  6. +2
    -4
      src/TensorFlowNET.Keras/Regularizers/L2.cs

+ 2
- 2
src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs View File

@@ -33,12 +33,12 @@ namespace Tensorflow.Keras.ArgsDefinition
/// <summary> /// <summary>
/// Regularizer function applied to the `kernel` weights matrix. /// Regularizer function applied to the `kernel` weights matrix.
/// </summary> /// </summary>
public IInitializer KernelRegularizer { get; set; }
public IRegularizer KernelRegularizer { get; set; }


/// <summary> /// <summary>
/// Regularizer function applied to the bias vector. /// Regularizer function applied to the bias vector.
/// </summary> /// </summary>
public IInitializer BiasRegularizer { get; set; }
public IRegularizer BiasRegularizer { get; set; }


/// <summary> /// <summary>
/// Constraint function applied to the `kernel` weights matrix. /// Constraint function applied to the `kernel` weights matrix.


+ 7
- 0
src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs View File

@@ -2,5 +2,12 @@
{ {
public class RegularizerArgs public class RegularizerArgs
{ {
public Tensor X { get; set; }


public RegularizerArgs(Tensor x)
{
X = x;
}
} }
} }

+ 7
- 4
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -203,7 +203,7 @@ namespace Tensorflow.Keras.Engine


protected virtual void add_loss(Func<Tensor> losses) protected virtual void add_loss(Func<Tensor> losses)
{ {
} }


/// <summary> /// <summary>
@@ -214,10 +214,13 @@ namespace Tensorflow.Keras.Engine
/// <param name="regularizer"></param> /// <param name="regularizer"></param>
void _handle_weight_regularization(string name, IVariableV1 variable, IRegularizer regularizer) void _handle_weight_regularization(string name, IVariableV1 variable, IRegularizer regularizer)
{ {
add_loss(() => regularizer.Apply(new RegularizerArgs
{


}));
add_loss(() => tf_with(ops.name_scope(name + "/Regularizer"), scope =>
regularizer.Apply(new RegularizerArgs(variable.AsTensor())
{

})
));
} }


/*protected virtual void add_update(Tensor[] updates, bool inputs = false) /*protected virtual void add_update(Tensor[] updates, bool inputs = false)


+ 19
- 0
src/TensorFlowNET.Keras/Regularizers/L1.cs View File

@@ -0,0 +1,19 @@
using System;

namespace Tensorflow.Keras
{
public class L1 : IRegularizer
{
float l1;

public L1(float l1 = 0.01f)
{
this.l1 = l1;
}

public Tensor Apply(RegularizerArgs args)
{
return l1 * math_ops.reduce_sum(math_ops.abs(args.X));
}
}
}

+ 24
- 0
src/TensorFlowNET.Keras/Regularizers/L1L2.cs View File

@@ -0,0 +1,24 @@
using System;
using static Tensorflow.Binding;
namespace Tensorflow.Keras
{
public class L1L2 : IRegularizer
{
float l1;
float l2;

public L1L2(float l1 = 0.0f, float l2 = 0.0f)
{
this.l1 = l1;
this.l2 = l2;

}
public Tensor Apply(RegularizerArgs args)
{
Tensor regularization = tf.constant(0.0, args.X.dtype);
regularization += l1 * math_ops.reduce_sum(math_ops.abs(args.X));
regularization += l2 * math_ops.reduce_sum(math_ops.square(args.X));
return regularization;
}
}
}

+ 2
- 4
src/TensorFlowNET.Keras/Regularizers/L2.cs View File

@@ -1,6 +1,4 @@
using System;

namespace Tensorflow.Keras
namespace Tensorflow.Keras
{ {
public class L2 : IRegularizer public class L2 : IRegularizer
{ {
@@ -13,7 +11,7 @@ namespace Tensorflow.Keras


public Tensor Apply(RegularizerArgs args) public Tensor Apply(RegularizerArgs args)
{ {
throw new NotImplementedException();
return l2 * math_ops.reduce_sum(math_ops.square(args.X));
} }
} }
} }

Loading…
Cancel
Save