diff --git a/src/TensorFlowNET.Core/GlobalUsing.cs b/src/TensorFlowNET.Core/GlobalUsing.cs
new file mode 100644
index 00000000..fe77202c
--- /dev/null
+++ b/src/TensorFlowNET.Core/GlobalUsing.cs
@@ -0,0 +1,3 @@
+global using System;
+global using System.Collections.Generic;
+global using System.Text;
diff --git a/src/TensorFlowNET.Core/Keras/IKerasApi.cs b/src/TensorFlowNET.Core/Keras/IKerasApi.cs
index 7f85f02f..49ec9a5f 100644
--- a/src/TensorFlowNET.Core/Keras/IKerasApi.cs
+++ b/src/TensorFlowNET.Core/Keras/IKerasApi.cs
@@ -2,12 +2,14 @@
using System.Collections.Generic;
using System.Text;
using Tensorflow.Keras.Layers;
+using Tensorflow.Keras.Losses;
namespace Tensorflow.Keras
{
public interface IKerasApi
{
public ILayersApi layers { get; }
+ public ILossesApi losses { get; }
public IInitializersApi initializers { get; }
}
}
diff --git a/src/TensorFlowNET.Core/Keras/Losses/ILossFunc.cs b/src/TensorFlowNET.Core/Keras/Losses/ILossFunc.cs
new file mode 100644
index 00000000..408c7ca1
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Losses/ILossFunc.cs
@@ -0,0 +1,8 @@
+namespace Tensorflow.Keras.Losses;
+
+public interface ILossFunc
+{
+ public string Reduction { get; }
+ public string Name { get; }
+ Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null);
+}
diff --git a/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs b/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs
new file mode 100644
index 00000000..c4249336
--- /dev/null
+++ b/src/TensorFlowNET.Core/Keras/Losses/ILossesApi.cs
@@ -0,0 +1,41 @@
+namespace Tensorflow.Keras.Losses;
+
+public interface ILossesApi
+{
+ ILossFunc BinaryCrossentropy(bool from_logits = false,
+ float label_smoothing = 0f,
+ int axis = -1,
+ string reduction = "auto",
+ string name = "binary_crossentropy");
+
+ ILossFunc SparseCategoricalCrossentropy(string reduction = null,
+ string name = null,
+ bool from_logits = false);
+
+ ILossFunc CategoricalCrossentropy(string reduction = null,
+ string name = null,
+ bool from_logits = false);
+
+ ILossFunc MeanSquaredError(string reduction = null,
+ string name = null);
+
+ ILossFunc MeanSquaredLogarithmicError(string reduction = null,
+ string name = null);
+
+ ILossFunc MeanAbsolutePercentageError(string reduction = null,
+ string name = null);
+
+ ILossFunc MeanAbsoluteError(string reduction = null,
+ string name = null);
+
+ ILossFunc CosineSimilarity(string reduction = null,
+ int axis = -1,
+ string name = null);
+
+ ILossFunc Huber(string reduction = null,
+ string name = null,
+ Tensor delta = null);
+
+ ILossFunc LogCosh(string reduction = null,
+ string name = null);
+}
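
For orientation, a minimal caller-side sketch of the new ILossesApi surface (the input values are borrowed from the unit test added later in this patch; tf, keras and print come from the Tensorflow.Binding / Tensorflow.KerasApi static usings):

```csharp
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

// Each factory method returns an ILossFunc; Call applies the configured
// reduction and an optional sample_weight.
var y_true = tf.constant(new float[] { 0f, 1f, 0f, 0f });
var y_pred = tf.constant(new float[] { -18.6f, 0.51f, 2.94f, -12.8f });

var bce = keras.losses.BinaryCrossentropy(from_logits: true);
var loss = bce.Call(y_true, y_pred);
print(loss.numpy());   // expected ~0.865458 with the default "auto" reduction
```
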
diff --git a/src/TensorFlowNET.Keras/BackendImpl.cs b/src/TensorFlowNET.Keras/BackendImpl.cs
index a62e8196..0c9da015 100644
--- a/src/TensorFlowNET.Keras/BackendImpl.cs
+++ b/src/TensorFlowNET.Keras/BackendImpl.cs
@@ -276,6 +276,20 @@ namespace Tensorflow.Keras
return -math_ops.reduce_sum(target * math_ops.log(output), new Axis(axis));
}
+ public Tensor binary_crossentropy(Tensor target, Tensor output, bool from_logits = false)
+ {
+ if (from_logits)
+ return tf.nn.sigmoid_cross_entropy_with_logits(labels: target, logits: output);
+
+ var epsilon_ = constant_op.constant(epsilon(), dtype: output.dtype.as_base_dtype());
+ output = tf.clip_by_value(output, epsilon_, 1.0f - epsilon_);
+
+ // Compute cross entropy from probabilities.
+ var bce = target * tf.math.log(output + epsilon());
+ bce += (1 - target) * tf.math.log(1 - output + epsilon());
+ return -bce;
+ }
+
/// <summary>
/// Resizes the images contained in a 4D tensor.
/// </summary>
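
The probability branch of binary_crossentropy above implements the element-wise formula -(t * log(p) + (1 - t) * log(1 - p)), with epsilon clipping for numerical stability. A plain-C# sanity check of that arithmetic (illustrative values only, no clipping):

```csharp
using System;

double t = 1.0, p = 0.8;   // target and predicted probability
double bce = -(t * Math.Log(p) + (1 - t) * Math.Log(1 - p));
Console.WriteLine(bce);    // ~0.2231
```
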
diff --git a/src/TensorFlowNET.Keras/GlobalUsing.cs b/src/TensorFlowNET.Keras/GlobalUsing.cs
new file mode 100644
index 00000000..72ff8b28
--- /dev/null
+++ b/src/TensorFlowNET.Keras/GlobalUsing.cs
@@ -0,0 +1,5 @@
+global using System;
+global using System.Collections.Generic;
+global using System.Text;
+global using static Tensorflow.Binding;
+global using static Tensorflow.KerasApi;
\ No newline at end of file
diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs
index 8dde1ab4..4e0c612b 100644
--- a/src/TensorFlowNET.Keras/KerasInterface.cs
+++ b/src/TensorFlowNET.Keras/KerasInterface.cs
@@ -21,7 +21,7 @@ namespace Tensorflow.Keras
public IInitializersApi initializers { get; } = new InitializersApi();
public Regularizers regularizers { get; } = new Regularizers();
public ILayersApi layers { get; } = new LayersApi();
- public LossesApi losses { get; } = new LossesApi();
+ public ILossesApi losses { get; } = new LossesApi();
public Activations activations { get; } = new Activations();
public Preprocessing preprocessing { get; } = new Preprocessing();
ThreadLocal<BackendImpl> _backend = new ThreadLocal<BackendImpl>(() => new BackendImpl());
diff --git a/src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs
new file mode 100644
index 00000000..ff7bb6b7
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Losses/BinaryCrossentropy.cs
@@ -0,0 +1,24 @@
+namespace Tensorflow.Keras.Losses;
+
+public class BinaryCrossentropy : LossFunctionWrapper, ILossFunc
+{
+ float label_smoothing;
+ public BinaryCrossentropy(
+ bool from_logits = false,
+ float label_smoothing = 0,
+ string reduction = null,
+ string name = null) :
+ base(reduction: reduction,
+ name: name == null ? "binary_crossentropy" : name,
+ from_logits: from_logits)
+ {
+ this.label_smoothing = label_smoothing;
+ }
+
+
+ public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
+ {
+ var sum = keras.backend.binary_crossentropy(y_true, y_pred, from_logits: from_logits);
+ return keras.backend.mean(sum, axis: axis);
+ }
+}
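
To illustrate the pattern this file follows, a hypothetical custom loss can plug into the same contract by deriving from LossFunctionWrapper and overriding Apply; Call in the Loss base class then handles reduction and sample weighting. The class name and formula below are invented for the example, and it relies on the global usings added by this patch:

```csharp
namespace Tensorflow.Keras.Losses;

public class MeanFourthPowerError : LossFunctionWrapper, ILossFunc
{
    public MeanFourthPowerError(string reduction = null, string name = null)
        : base(reduction: reduction, name: name ?? "mean_fourth_power_error")
    {
    }

    public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
    {
        // Element-wise (y_pred - y_true)^4, averaged over the last axis;
        // the base class applies the reduction afterwards.
        var diff = y_pred - y_true;
        return keras.backend.mean(diff * diff * diff * diff, axis: axis);
    }
}
```
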
diff --git a/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs
index c80b1a83..feb05224 100644
--- a/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs
+++ b/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs
@@ -1,31 +1,24 @@
-using System;
-using System.Collections.Generic;
-using System.Text;
-using static Tensorflow.Binding;
-using static Tensorflow.KerasApi;
+namespace Tensorflow.Keras.Losses;
-namespace Tensorflow.Keras.Losses
+public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc
{
- public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc
+ float label_smoothing;
+ public CategoricalCrossentropy(
+ bool from_logits = false,
+ float label_smoothing = 0,
+ string reduction = null,
+ string name = null) :
+ base(reduction: reduction,
+ name: name == null ? "categorical_crossentropy" : name,
+ from_logits: from_logits)
{
- float label_smoothing;
- public CategoricalCrossentropy(
- bool from_logits = false,
- float label_smoothing = 0,
- string reduction = null,
- string name = null) :
- base(reduction: reduction,
- name: name == null ? "categorical_crossentropy" : name,
- from_logits: from_logits)
- {
- this.label_smoothing = label_smoothing;
- }
+ this.label_smoothing = label_smoothing;
+ }
- public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
- {
- // Try to adjust the shape so that rank of labels = rank of logits - 1.
- return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits);
- }
+ public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
+ {
+ // Try to adjust the shape so that rank of labels = rank of logits - 1.
+ return keras.backend.categorical_crossentropy(y_true, y_pred, from_logits: from_logits);
}
}
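
A short usage sketch with one-hot targets; the numbers mirror the tf.keras documentation example and are only illustrative:

```csharp
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

var y_true = tf.constant(new float[,] { { 0, 1, 0 }, { 0, 0, 1 } });
var y_pred = tf.constant(new float[,] { { 0.05f, 0.95f, 0f }, { 0.1f, 0.8f, 0.1f } });

var cce = keras.losses.CategoricalCrossentropy();
var loss = cce.Call(y_true, y_pred);
print(loss.numpy());   // ~1.177 with the default (mean) reduction
```
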
diff --git a/src/TensorFlowNET.Keras/Losses/ILossFunc.cs b/src/TensorFlowNET.Keras/Losses/ILossFunc.cs
deleted file mode 100644
index 8bc226df..00000000
--- a/src/TensorFlowNET.Keras/Losses/ILossFunc.cs
+++ /dev/null
@@ -1,9 +0,0 @@
-namespace Tensorflow.Keras.Losses
-{
- public interface ILossFunc
- {
- public string Reduction { get; }
- public string Name { get; }
- Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null);
- }
-}
diff --git a/src/TensorFlowNET.Keras/Losses/Loss.cs b/src/TensorFlowNET.Keras/Losses/Loss.cs
index fe017ac4..77bf7e1d 100644
--- a/src/TensorFlowNET.Keras/Losses/Loss.cs
+++ b/src/TensorFlowNET.Keras/Losses/Loss.cs
@@ -16,7 +16,7 @@ namespace Tensorflow.Keras.Losses
public string Reduction => reduction;
public string Name => name;
- public Loss(string reduction = ReductionV2.AUTO,
+ public Loss(string reduction = ReductionV2.AUTO,
string name = null,
bool from_logits = false)
{
@@ -34,7 +34,17 @@ namespace Tensorflow.Keras.Losses
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
var losses = Apply(y_true, y_pred, from_logits: from_logits);
- return losses_utils.compute_weighted_loss(losses, reduction: this.reduction , sample_weight: sample_weight);
+ var reduction = GetReduction();
+ return losses_utils.compute_weighted_loss(losses, reduction: reduction, sample_weight: sample_weight);
+ }
+
+ string GetReduction()
+ {
+ return reduction switch
+ {
+ ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE,
+ _ => reduction
+ };
}
void _set_name_scope()
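
GetReduction() resolves "auto" to "sum_over_batch_size" before the weighted loss is computed, so per-sample losses are averaged by default rather than summed. A plain-C# sketch of the difference, using the per-sample values from the unit test added in this patch:

```csharp
using System;
using System.Linq;

var perSample = new[] { 0.23515666f, 1.4957594f };      // "none" reduction output
Console.WriteLine(perSample.Sum() / perSample.Length);  // sum_over_batch_size ~0.865458
Console.WriteLine(perSample.Sum());                     // sum ~1.730916
```
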
diff --git a/src/TensorFlowNET.Keras/Losses/LossesApi.cs b/src/TensorFlowNET.Keras/Losses/LossesApi.cs
index 71cffebb..29e15e53 100644
--- a/src/TensorFlowNET.Keras/Losses/LossesApi.cs
+++ b/src/TensorFlowNET.Keras/Losses/LossesApi.cs
@@ -1,7 +1,17 @@
namespace Tensorflow.Keras.Losses
{
- public class LossesApi
+ public class LossesApi : ILossesApi
{
+ public ILossFunc BinaryCrossentropy(bool from_logits = false,
+ float label_smoothing = 0,
+ int axis = -1,
+ string reduction = "auto",
+ string name = "binary_crossentropy")
+ => new BinaryCrossentropy(from_logits: from_logits,
+ label_smoothing: label_smoothing,
+ reduction: reduction,
+ name: name);
+
public ILossFunc SparseCategoricalCrossentropy(string reduction = null, string name = null,bool from_logits = false)
=> new SparseCategoricalCrossentropy(reduction: reduction, name: name,from_logits: from_logits);
@@ -19,14 +29,13 @@
public ILossFunc MeanAbsoluteError(string reduction = null, string name = null)
=> new MeanAbsoluteError(reduction: reduction, name: name);
- public ILossFunc CosineSimilarity(string reduction = null, string name = null,int axis=-1)
- => new CosineSimilarity(reduction: reduction, name: name, axis: axis);
+ public ILossFunc CosineSimilarity(string reduction = null, int axis = -1, string name = null)
+ => new CosineSimilarity(reduction: reduction, axis: axis, name: name);
public ILossFunc Huber(string reduction = null, string name = null, Tensor delta=null)
=> new Huber(reduction: reduction, name: name, delta: delta);
public ILossFunc LogCosh(string reduction = null, string name = null)
=> new LogCosh(reduction: reduction, name: name);
-
}
}
diff --git a/src/TensorFlowNET.Keras/Utils/losses_utils.cs b/src/TensorFlowNET.Keras/Utils/losses_utils.cs
index 8a8772fd..08330595 100644
--- a/src/TensorFlowNET.Keras/Utils/losses_utils.cs
+++ b/src/TensorFlowNET.Keras/Utils/losses_utils.cs
@@ -24,23 +24,17 @@ namespace Tensorflow.Keras.Utils
{
public static Tensor compute_weighted_loss(Tensor losses, Tensor sample_weight = null, string reduction = null, string name = null)
{
- if (sample_weight == null)
- sample_weight = losses.dtype == TF_DataType.TF_DOUBLE ? tf.constant(1.0) : tf.constant(1.0f);
- var weighted_losses = scale_losses_by_sample_weight(losses, sample_weight);
- // Apply reduction function to the individual weighted losses.
- var loss = reduce_weighted_loss(weighted_losses, reduction);
- // Convert the result back to the input type.
- // loss = math_ops.cast(loss, losses.dtype);
- return loss;
- }
-
- public static Tensor scale_losses_by_sample_weight(Tensor losses, Tensor sample_weight)
- {
- // losses = math_ops.cast(losses, dtypes.float32);
- // sample_weight = math_ops.cast(sample_weight, dtypes.float32);
- // Update dimensions of `sample_weight` to match with `losses` if possible.
- // (losses, sample_weight) = squeeze_or_expand_dimensions(losses, sample_weight);
- return math_ops.multiply(losses, sample_weight);
+ return tf_with(ops.name_scope("weighted_loss"), scope =>
+ {
+ if (sample_weight == null)
+ sample_weight = losses.dtype == TF_DataType.TF_DOUBLE ? tf.constant(1.0) : tf.constant(1.0f);
+ var weighted_losses = math_ops.multiply(losses, sample_weight);
+ // Apply reduction function to the individual weighted losses.
+ var loss = reduce_weighted_loss(weighted_losses, reduction);
+ // Convert the result back to the input type.
+ // loss = math_ops.cast(loss, losses.dtype);
+ return loss;
+ });
}
public static (Tensor, Tensor) squeeze_or_expand_dimensions(Tensor y_pred, Tensor sample_weight)
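
The rewritten compute_weighted_loss multiplies each per-sample loss by its sample_weight inside a "weighted_loss" name scope and then applies the reduction. A plain-C# sketch of that arithmetic with the values from the new unit test (sample_weight = { 0.8, 0.2 }, mean reduction):

```csharp
using System;

var losses = new[] { 0.23515666f, 1.4957594f };
var weights = new[] { 0.8f, 0.2f };

float weighted = losses[0] * weights[0] + losses[1] * weights[1];
Console.WriteLine(weighted / losses.Length);   // ~0.2436386
```
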
diff --git a/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs b/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs
new file mode 100644
index 00000000..dad46c55
--- /dev/null
+++ b/test/TensorFlowNET.Keras.UnitTest/Losses/LossesTest.cs
@@ -0,0 +1,50 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using System.Threading.Tasks;
+using TensorFlowNET.Keras.UnitTest;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.UnitTest.Losses;
+
+[TestClass]
+public class LossesTest : EagerModeTestBase
+{
+ /// <summary>
+ /// https://www.tensorflow.org/api_docs/python/tf/keras/losses/BinaryCrossentropy
+ /// </summary>
+ [TestMethod]
+ public void BinaryCrossentropy()
+ {
+ // Example 1: (batch_size = 1, number of samples = 4)
+ var y_true = tf.constant(new float[] { 0, 1, 0, 0 });
+ var y_pred = tf.constant(new float[] { -18.6f, 0.51f, 2.94f, -12.8f });
+ var bce = tf.keras.losses.BinaryCrossentropy(from_logits: true);
+ var loss = bce.Call(y_true, y_pred);
+ Assert.AreEqual((float)loss, 0.865458f);
+
+ // Example 2: (batch_size = 2, number of samples = 4)
+ y_true = tf.constant(new float[,] { { 0, 1 }, { 0, 0 } });
+ y_pred = tf.constant(new float[,] { { -18.6f, 0.51f }, { 2.94f, -12.8f } });
+ bce = tf.keras.losses.BinaryCrossentropy(from_logits: true);
+ loss = bce.Call(y_true, y_pred);
+ Assert.AreEqual((float)loss, 0.865458f);
+
+ // Using 'sample_weight' attribute
+ loss = bce.Call(y_true, y_pred, sample_weight: tf.constant(new[] { 0.8f, 0.2f }));
+ Assert.AreEqual((float)loss, 0.2436386f);
+
+ // Using 'sum' reduction type.
+ bce = tf.keras.losses.BinaryCrossentropy(from_logits: true, reduction: Reduction.SUM);
+ loss = bce.Call(y_true, y_pred);
+ Assert.AreEqual((float)loss, 1.730916f);
+
+ // Using 'none' reduction type.
+ bce = tf.keras.losses.BinaryCrossentropy(from_logits: true, reduction: Reduction.NONE);
+ loss = bce.Call(y_true, y_pred);
+ Assert.AreEqual(new float[] { 0.23515666f, 1.4957594f}, loss.numpy());
+ }
+}
diff --git a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj
index 61e522e6..c9020f7b 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj
+++ b/test/TensorFlowNET.Keras.UnitTest/Tensorflow.Keras.UnitTest.csproj
@@ -1,4 +1,4 @@
-
+
<TargetFramework>net6.0</TargetFramework>