diff --git a/src/TensorFlowNET.Core/Keras/Activations/Activations.cs b/src/TensorFlowNET.Core/Keras/Activations/Activations.cs
index b8f321e6..3dde625e 100644
--- a/src/TensorFlowNET.Core/Keras/Activations/Activations.cs
+++ b/src/TensorFlowNET.Core/Keras/Activations/Activations.cs
@@ -1,4 +1,44 @@
-namespace Tensorflow.Keras
+using Newtonsoft.Json;
+using System.Reflection;
+using System.Runtime.Versioning;
+using Tensorflow.Keras.Common;
+
+namespace Tensorflow.Keras
{
- public delegate Tensor Activation(Tensor features, string name = null);
+ [JsonConverter(typeof(CustomizedActivationJsonConverter))]
+ public class Activation
+ {
+ public string Name { get; set; }
+ /// <summary>
+ /// The parameters are `features` and `name`.
+ /// </summary>
+ public Func<Tensor, string, Tensor> ActivationFunction { get; set; }
+
+ public Tensor Apply(Tensor input, string name = null) => ActivationFunction(input, name);
+
+ public static implicit operator Activation(Func func)
+ {
+ return new Activation()
+ {
+ Name = func.GetMethodInfo().Name,
+ ActivationFunction = func
+ };
+ }
+ }
+
+ public interface IActivationsApi
+ {
+ Activation GetActivationFromName(string name);
+ Activation Linear { get; }
+
+ Activation Relu { get; }
+
+ Activation Sigmoid { get; }
+
+ Activation Softmax { get; }
+
+ Activation Tanh { get; }
+
+ Activation Mish { get; }
+ }
}
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs
index 08d563c1..a0724630 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Convolution/ConvolutionalArgs.cs
@@ -26,27 +26,8 @@ namespace Tensorflow.Keras.ArgsDefinition
public Shape DilationRate { get; set; } = (1, 1);
[JsonProperty("groups")]
public int Groups { get; set; } = 1;
- public Activation Activation { get; set; }
- private string _activationName;
[JsonProperty("activation")]
- public string ActivationName
- {
- get
- {
- if (string.IsNullOrEmpty(_activationName))
- {
- return Activation.Method.Name;
- }
- else
- {
- return _activationName;
- }
- }
- set
- {
- _activationName = value;
- }
- }
+ public Activation Activation { get; set; }
[JsonProperty("use_bias")]
public bool UseBias { get; set; }
[JsonProperty("kernel_initializer")]
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs
index 8f4facbd..0caa76ef 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs
@@ -18,28 +18,8 @@ namespace Tensorflow.Keras.ArgsDefinition
/// <summary>
/// Activation function to use.
/// </summary>
- public Activation Activation { get; set; }
-
- private string _activationName;
[JsonProperty("activation")]
- public string ActivationName
- {
- get
- {
- if (string.IsNullOrEmpty(_activationName))
- {
- return Activation.Method.Name;
- }
- else
- {
- return _activationName;
- }
- }
- set
- {
- _activationName = value;
- }
- }
+ public Activation Activation { get; set; }
/// <summary>
/// Whether the layer uses a bias vector.
diff --git a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs
index 9817e9c6..e6030972 100644
--- a/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs
+++ b/src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/EinsumDenseArgs.cs
@@ -35,27 +35,8 @@ namespace Tensorflow.Keras.ArgsDefinition.Core
/// <summary>
/// Activation function to use.
/// </summary>
- public Activation Activation { get; set; }
- private string _activationName;
[JsonProperty("activation")]
- public string ActivationName
- {
- get
- {
- if (string.IsNullOrEmpty(_activationName))
- {
- return Activation.Method.Name;
- }
- else
- {
- return _activationName;
- }
- }
- set
- {
- _activationName = value;
- }
- }
+ public Activation Activation { get; set; }
/// <summary>
/// Initializer for the `kernel` weights matrix.
diff --git a/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs b/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs
index 1bc13caf..04ee79e3 100644
--- a/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs
+++ b/src/TensorFlowNET.Core/Keras/Common/CustomizedActivationJsonConverter.cs
@@ -4,6 +4,7 @@ using Newtonsoft.Json.Linq;
using System;
using System.Collections.Generic;
using System.Text;
+using static Tensorflow.Binding;
namespace Tensorflow.Keras.Common
{
@@ -31,20 +32,19 @@ namespace Tensorflow.Keras.Common
}
else
{
- var token = JToken.FromObject((value as Activation)!.GetType().Name);
+ var token = JToken.FromObject(((Activation)value).Name);
token.WriteTo(writer);
}
}
public override object? ReadJson(JsonReader reader, Type objectType, object? existingValue, JsonSerializer serializer)
{
- throw new NotImplementedException();
- //var dims = serializer.Deserialize(reader, typeof(string));
- //if (dims is null)
- //{
- // throw new ValueError("Cannot deserialize 'null' to `Activation`.");
- //}
- //return new Shape((long[])(dims!));
+ var activationName = serializer.Deserialize<string>(reader);
+ if (tf.keras is null)
+ {
+ throw new RuntimeError("Tensorflow.Keras is not loaded, please install it first.");
+ }
+ return tf.keras.activations.GetActivationFromName(string.IsNullOrEmpty(activationName) ? "linear" : activationName);
}
}
}
diff --git a/src/TensorFlowNET.Core/Keras/IKerasApi.cs b/src/TensorFlowNET.Core/Keras/IKerasApi.cs
index 0b46d27d..db8deb24 100644
--- a/src/TensorFlowNET.Core/Keras/IKerasApi.cs
+++ b/src/TensorFlowNET.Core/Keras/IKerasApi.cs
@@ -16,6 +16,7 @@ namespace Tensorflow.Keras
IInitializersApi initializers { get; }
ILayersApi layers { get; }
ILossesApi losses { get; }
+ IActivationsApi activations { get; }
IOptimizerApi optimizers { get; }
IMetricsApi metrics { get; }
IModelsApi models { get; }
diff --git a/src/TensorFlowNET.Keras/Activations.cs b/src/TensorFlowNET.Keras/Activations.cs
index 37bddac7..00de728f 100644
--- a/src/TensorFlowNET.Keras/Activations.cs
+++ b/src/TensorFlowNET.Keras/Activations.cs
@@ -6,45 +6,61 @@ using static Tensorflow.Binding;
namespace Tensorflow.Keras
{
- public class Activations
+ public class Activations: IActivationsApi
{
private static Dictionary<string, Activation> _nameActivationMap;
- private static Dictionary<Activation, string> _activationNameMap;
- private static Activation _linear = (features, name) => features;
- private static Activation _relu = (features, name)
- => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features));
- private static Activation _sigmoid = (features, name)
- => tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features));
- private static Activation _softmax = (features, name)
- => tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features));
- private static Activation _tanh = (features, name)
- => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features));
- private static Activation _mish = (features, name)
- => features * tf.math.tanh(tf.math.softplus(features));
+ private static Activation _linear = new Activation()
+ {
+ Name = "linear",
+ ActivationFunction = (features, name) => features
+ };
+ private static Activation _relu = new Activation()
+ {
+ Name = "relu",
+ ActivationFunction = (features, name) => tf.Context.ExecuteOp("Relu", name, new ExecuteOpArgs(features))
+ };
+ private static Activation _sigmoid = new Activation()
+ {
+ Name = "sigmoid",
+ ActivationFunction = (features, name) => tf.Context.ExecuteOp("Sigmoid", name, new ExecuteOpArgs(features))
+ };
+ private static Activation _softmax = new Activation()
+ {
+ Name = "softmax",
+ ActivationFunction = (features, name) => tf.Context.ExecuteOp("Softmax", name, new ExecuteOpArgs(features))
+ };
+ private static Activation _tanh = new Activation()
+ {
+ Name = "tanh",
+ ActivationFunction = (features, name) => tf.Context.ExecuteOp("Tanh", name, new ExecuteOpArgs(features))
+ };
+ private static Activation _mish = new Activation()
+ {
+ Name = "mish",
+ ActivationFunction = (features, name) => features * tf.math.tanh(tf.math.softplus(features))
+ };
/// <summary>
/// Register the name-activation mapping in this static class.
/// </summary>
/// <param name="name"></param>
/// <param name="activation"></param>
- private static void RegisterActivation(string name, Activation activation)
+ private static void RegisterActivation(Activation activation)
{
- _nameActivationMap[name] = activation;
- _activationNameMap[activation] = name;
+ _nameActivationMap[activation.Name] = activation;
}
static Activations()
{
_nameActivationMap = new Dictionary<string, Activation>();
- _activationNameMap= new Dictionary<Activation, string>();
- RegisterActivation("relu", _relu);
- RegisterActivation("linear", _linear);
- RegisterActivation("sigmoid", _sigmoid);
- RegisterActivation("softmax", _softmax);
- RegisterActivation("tanh", _tanh);
- RegisterActivation("mish", _mish);
+ RegisterActivation(_relu);
+ RegisterActivation(_linear);
+ RegisterActivation(_sigmoid);
+ RegisterActivation(_softmax);
+ RegisterActivation(_tanh);
+ RegisterActivation(_mish);
}
public Activation Linear => _linear;
@@ -59,7 +75,7 @@ namespace Tensorflow.Keras
public Activation Mish => _mish;
- public static Activation GetActivationByName(string name)
+ public Activation GetActivationFromName(string name)
{
if (!_nameActivationMap.TryGetValue(name, out var res))
{
@@ -70,17 +86,5 @@ namespace Tensorflow.Keras
return res;
}
}
-
- public static string GetNameByActivation(Activation activation)
- {
- if(!_activationNameMap.TryGetValue(activation, out var name))
- {
- throw new Exception($"Activation {activation} not found");
- }
- else
- {
- return name;
- }
- }
}
}
diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs
index 8bd1e682..9f1746d8 100644
--- a/src/TensorFlowNET.Keras/KerasInterface.cs
+++ b/src/TensorFlowNET.Keras/KerasInterface.cs
@@ -45,7 +45,7 @@ namespace Tensorflow.Keras
public Regularizers regularizers { get; } = new Regularizers();
public ILayersApi layers { get; } = new LayersApi();
public ILossesApi losses { get; } = new LossesApi();
- public Activations activations { get; } = new Activations();
+ public IActivationsApi activations { get; } = new Activations();
public Preprocessing preprocessing { get; } = new Preprocessing();
ThreadLocal<BackendImpl> _backend = new ThreadLocal<BackendImpl>(() => new BackendImpl());
public BackendImpl backend => _backend.Value;
diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
index b8286be6..7b281b28 100644
--- a/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
+++ b/src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs
@@ -110,7 +110,7 @@ namespace Tensorflow.Keras.Layers
throw new NotImplementedException("");
if (activation != null)
- return activation(outputs);
+ return activation.Apply(outputs);
return outputs;
}
diff --git a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
index 933aa9cf..8f6a6c5b 100644
--- a/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
+++ b/src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs
@@ -117,7 +117,7 @@ namespace Tensorflow.Keras.Layers
}
if (activation != null)
- outputs = activation(outputs);
+ outputs = activation.Apply(outputs);
return outputs;
}
diff --git a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs
index 56fde9f2..decdcb1d 100644
--- a/src/TensorFlowNET.Keras/Layers/Core/Dense.cs
+++ b/src/TensorFlowNET.Keras/Layers/Core/Dense.cs
@@ -81,7 +81,7 @@ namespace Tensorflow.Keras.Layers
if (args.UseBias)
outputs = tf.nn.bias_add(outputs, bias);
if (args.Activation != null)
- outputs = activation(outputs);
+ outputs = activation.Apply(outputs);
return outputs;
}
diff --git a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
index af71ddf9..c928591f 100644
--- a/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
+++ b/src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs
@@ -193,7 +193,7 @@ namespace Tensorflow.Keras.Layers
if (this.bias != null)
ret += this.bias.AsTensor();
if (this.activation != null)
- ret = this.activation(ret);
+ ret = this.activation.Apply(ret);
return ret;
}
/// <summary>
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index cf689edf..22fd661d 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -109,7 +109,7 @@ namespace Tensorflow.Keras.Layers
DilationRate = dilation_rate,
Groups = groups,
UseBias = use_bias,
- Activation = Activations.GetActivationByName(activation),
+ Activation = keras.activations.GetActivationFromName(activation),
KernelInitializer = GetInitializerByName(kernel_initializer),
BiasInitializer = GetInitializerByName(bias_initializer)
});
@@ -211,8 +211,7 @@ namespace Tensorflow.Keras.Layers
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
- Activation = Activations.GetActivationByName(activation),
- ActivationName = activation
+ Activation = keras.activations.GetActivationFromName(activation)
});
/// <summary>
@@ -257,7 +256,7 @@ namespace Tensorflow.Keras.Layers
UseBias = use_bias,
KernelInitializer = GetInitializerByName(kernel_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
- Activation = Activations.GetActivationByName(activation)
+ Activation = keras.activations.GetActivationFromName(activation)
});
/// <summary>
@@ -302,8 +301,7 @@ namespace Tensorflow.Keras.Layers
=> new Dense(new DenseArgs
{
Units = units,
- Activation = Activations.GetActivationByName("linear"),
- ActivationName = "linear"
+ Activation = keras.activations.GetActivationFromName("linear")
});
/// <summary>
@@ -323,8 +321,7 @@ namespace Tensorflow.Keras.Layers
=> new Dense(new DenseArgs
{
Units = units,
- Activation = Activations.GetActivationByName(activation),
- ActivationName = activation,
+ Activation = keras.activations.GetActivationFromName(activation),
InputShape = input_shape
});
@@ -704,7 +701,7 @@ namespace Tensorflow.Keras.Layers
=> new SimpleRNN(new SimpleRNNArgs
{
Units = units,
- Activation = Activations.GetActivationByName(activation),
+ Activation = keras.activations.GetActivationFromName(activation),
KernelInitializer = GetInitializerByName(kernel_initializer),
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
BiasInitializer = GetInitializerByName(bias_initializer),
@@ -852,7 +849,6 @@ namespace Tensorflow.Keras.Layers
public ILayer GlobalMaxPooling2D(string data_format = "channels_last")
=> new GlobalMaxPooling2D(new Pooling2DArgs { DataFormat = data_format });
- Activation GetActivationByName(string name) => Activations.GetActivationByName(name);
/// <summary>
/// Get an weights initializer from its name.
/// </summary>
diff --git a/test/TensorFlowNET.Keras.UnitTest/GradientTest.cs b/test/TensorFlowNET.Keras.UnitTest/GradientTest.cs
index f20eae0e..6ea2eb85 100644
--- a/test/TensorFlowNET.Keras.UnitTest/GradientTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/GradientTest.cs
@@ -5,6 +5,8 @@ using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;
+using System;
+using Tensorflow.Keras.Optimizers;
namespace TensorFlowNET.Keras.UnitTest;
@@ -40,7 +42,7 @@ public class GradientTest : EagerModeTestBase
}
[TestMethod]
- public void GetGradient_Test()
+ public void GetGradientTest()
{
var numStates = 3;
var numActions = 1;
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs
index 1f45c518..6fe9ca50 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/ActivationTest.cs
@@ -102,7 +102,7 @@ namespace TensorFlowNET.Keras.UnitTest {
public void Mish()
{
var x = tf.constant(new[] { 1.0, 0.0, 1.0 }, dtype: tf.float32);
- var output = keras.activations.Mish(x);
+ var output = keras.activations.Mish.Apply(x);
Assert.AreEqual(new[] { 0.86509836f, 0f, 0.86509836f }, output.numpy());
}
}
diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs
index 0c912607..9a6f35f6 100644
--- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelSave.cs
@@ -18,7 +18,7 @@ public class SequentialModelSave
{
var inputs = tf.keras.layers.Input((28, 28, 1));
var x = tf.keras.layers.Flatten().Apply(inputs);
- x = tf.keras.layers.Dense(100, activation: tf.nn.relu).Apply(x);
+ x = tf.keras.layers.Dense(100, activation: "relu").Apply(x);
x = tf.keras.layers.Dense(units: 10).Apply(x);
var outputs = tf.keras.layers.Softmax(axis: 1).Apply(x);
var model = tf.keras.Model(inputs, outputs);
@@ -110,7 +110,7 @@ public class SequentialModelSave
tf.keras.layers.Softmax(1)
});
- model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
+ model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: true), new string[] { "accuracy" });
var num_epochs = 1;
var batch_size = 8;