diff --git a/README.md b/README.md
index 5ee82c93..2fd46c6b 100644
--- a/README.md
+++ b/README.md
@@ -56,30 +56,32 @@ PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
Import TF.NET and Keras API in your project.
-```cs
+```csharp
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
+using Tensorflow;
+using NumSharp;
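+// NumSharp provides the NumPy-style np API (np.array, NDArray) used in the examples below.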
```
Linear Regression in `Eager` mode:
-```c#
+```csharp
// Parameters
var training_steps = 1000;
var learning_rate = 0.01f;
var display_step = 100;
// Sample data
-var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
+var X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f);
-var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
+var Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);
-var n_samples = train_X.shape[0];
+var n_samples = X.shape[0];
// Use fixed initial values so the demo is reproducible
var W = tf.Variable(-0.06f, name: "weight");
var b = tf.Variable(-0.73f, name: "bias");
-var optimizer = tf.optimizers.SGD(learning_rate);
+var optimizer = keras.optimizers.SGD(learning_rate);
// Run training for the given number of steps.
foreach (var step in range(1, training_steps + 1))
@@ -112,46 +114,40 @@ Run this example in [Jupyter Notebook](https://github.com/SciSharp/SciSharpCube)
Toy version of `ResNet` in `Keras` functional API:
```csharp
+var layers = new LayersApi();
// input layer
var inputs = keras.Input(shape: (32, 32, 3), name: "img");
-
// convolutional layer
var x = layers.Conv2D(32, 3, activation: "relu").Apply(inputs);
x = layers.Conv2D(64, 3, activation: "relu").Apply(x);
var block_1_output = layers.MaxPooling2D(3).Apply(x);
-
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_1_output);
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x);
-var block_2_output = layers.add(x, block_1_output);
-
+var block_2_output = layers.Add().Apply(new Tensors(x, block_1_output));
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_2_output);
x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x);
-var block_3_output = layers.add(x, block_2_output);
-
+var block_3_output = layers.Add().Apply(new Tensors(x, block_2_output));
x = layers.Conv2D(64, 3, activation: "relu").Apply(block_3_output);
x = layers.GlobalAveragePooling2D().Apply(x);
x = layers.Dense(256, activation: "relu").Apply(x);
x = layers.Dropout(0.5f).Apply(x);
-
// output layer
var outputs = layers.Dense(10).Apply(x);
-
// build keras model
-model = keras.Model(inputs, outputs, name: "toy_resnet");
+var model = keras.Model(inputs, outputs, name: "toy_resnet");
model.summary();
-
// compile keras model in tensorflow static graph
model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
- loss: keras.losses.CategoricalCrossentropy(from_logits: true),
- metrics: new[] { "acc" });
-
+ loss: keras.losses.CategoricalCrossentropy(from_logits: true),
+ metrics: new[] { "acc" });
// prepare dataset
var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
-
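+// scale pixels to [0, 1] and one-hot encode the 10 CIFAR-10 labels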
+x_train = x_train / 255.0f;
+y_train = np_utils.to_categorical(y_train, 10);
// training
-model.fit(x_train[new Slice(0, 1000)], y_train[new Slice(0, 1000)],
- batch_size: 64,
- epochs: 10,
+model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)],
+ batch_size: 64,
+ epochs: 10,
validation_split: 0.2f);
```
@@ -260,4 +256,4 @@ WeChat Sponsor 微信打赏:
TensorFlow.NET is a part of [SciSharp STACK](https://scisharp.github.io/SciSharp/)
-
+
\ No newline at end of file
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index b0d741f1..5c335b81 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -506,6 +506,27 @@ namespace Tensorflow
}
}
+
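+ /// <summary>
+ /// Port of tf.where (v2): with only a condition it returns the coordinates of
+ /// true elements; with both x and y it selects between them elementwise.
+ /// </summary>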
+ public static Tensor where_v2(Tensor condition, object x = null, object y = null, string name = null)
+ {
+ if (x == null && y == null)
+ {
+ return tf_with(ops.name_scope(name, "Where", new { condition }), scope =>
+ {
+ name = scope;
+ condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition");
+ return gen_array_ops.where(condition: condition, name: name);
+ });
+ }
+ else if (x != null && y != null)
+ {
+ return gen_array_ops.select_v2(condition, x, y, name);
+ }
+ else
+ {
+ throw new ValueError("x and y must both be non-null or both be null.");
+ }
+ }
/// <summary>
/// Returns the shape of a tensor.
/// </summary>
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 50dde5c3..019d19be 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -423,6 +423,21 @@ namespace Tensorflow
var _op = tf.OpDefLib._apply_op_helper("Select", name, new { condition, t = x, e = y });
return _op.outputs[0];
}
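+
+ /// <summary>
+ /// Wraps the SelectV2 op, which (unlike the legacy Select above) broadcasts
+ /// condition, x and y against each other; takes the eager fast path when available.
+ /// </summary>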
+ public static Tensor select_v2<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
+ {
+ if (tf.Context.executing_eagerly())
+ {
+ var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+ "SelectV2", name,
+ null,
+ condition, x, y);
+
+ return results[0];
+ }
+
+ var _op = tf.OpDefLib._apply_op_helper("SelectV2", name, new { condition, t = x, e = y });
+ return _op.outputs[0];
+ }
public static Tensor scatter_nd(Tensor indices, Tensor updates, Tensor[] shape, string name = null)
{
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index 67b5b614..ca74ea5f 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -714,7 +714,23 @@ namespace Tensorflow
return _op.outputs[0];
}
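+
+ /// <summary>
+ /// Computes softplus, log(exp(features) + 1), using the eager fast path when available.
+ /// </summary>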
+ public static Tensor softplus(Tensor features, string name = null)
+ {
+ if (tf.Context.executing_eagerly())
+ {
+ var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+ "Softplus", name,
+ null,
+ features);
+ return results[0];
+ }
+
+ var _op = tf.OpDefLib._apply_op_helper("Softplus", name, args: new { features });
+
+ return _op.outputs[0];
+ }
+
public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate = false, string name = null)
=> tf.Context.RunInAutoMode(()
=> tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }).output, ()
@@ -1068,6 +1084,15 @@ namespace Tensorflow
public static Tensor _abs(Tensor x, string name = null)
{
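+ // Fast path: dispatch Abs through the eager runtime instead of building a graph op.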
+ if (tf.Context.executing_eagerly())
+ {
+ var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+ "Abs", name,
+ null,
+ x);
+
+ return results[0];
+ }
var _op = tf.OpDefLib._apply_op_helper("Abs", name, args: new { x });
return _op.output;
@@ -1202,6 +1227,15 @@ namespace Tensorflow
/// </summary>
public static Tensor rsqrt(Tensor x, string name = null)
{
+ if (tf.Context.executing_eagerly())
+ {
+ var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+ "Rsqrt", name,
+ null,
+ x);
+
+ return results[0];
+ }
var _op = tf.OpDefLib._apply_op_helper("Rsqrt", name, new { x });
return _op.outputs[0];
diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
index 121a4728..7b008e4e 100644
--- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
@@ -31,7 +31,7 @@ namespace Tensorflow
/// </summary>
public static Tensor l2_normalize(Tensor x,
int axis = 0,
- float epsilon = 1e-12f,
+ Tensor epsilon = null,
string name = null)
{
return tf_with(ops.name_scope(name, "l2_normalize", new { x }), scope =>
@@ -39,7 +39,7 @@ namespace Tensorflow
x = ops.convert_to_tensor(x, name: "x");
var sq = math_ops.square(x);
var square_sum = math_ops.reduce_sum(sq, axis, keepdims: true);
- var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon));
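+ // Default epsilon of 1e-12 guards against division by zero when the norm is ~0.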
+ var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon ?? tf.constant(1e-12f)));
return math_ops.multiply(x, x_inv_norm, name: name);
});
}
diff --git a/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs
index bf5dbb64..c80b1a83 100644
--- a/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs
+++ b/src/TensorFlowNET.Keras/Losses/CategoricalCrossentropy.cs
@@ -9,18 +9,19 @@ namespace Tensorflow.Keras.Losses
public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc
{
float label_smoothing;
-
- public CategoricalCrossentropy(bool from_logits = false,
+ public CategoricalCrossentropy(
+ bool from_logits = false,
float label_smoothing = 0,
- string reduction = ReductionV2.AUTO,
- string name = "categorical_crossentropy") :
- base(reduction: reduction,
- name: name,
- from_logits: from_logits)
+ string reduction = null,
+ string name = null) :
+ base(reduction: reduction,
+ name: name ?? "categorical_crossentropy",
+ from_logits: from_logits)
{
this.label_smoothing = label_smoothing;
}
+
public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
{
// Try to adjust the shape so that rank of labels = rank of logits - 1.
diff --git a/src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs b/src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs
new file mode 100644
index 00000000..57debbc9
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Losses/CosineSimilarity.cs
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+ public class CosineSimilarity : LossFunctionWrapper, ILossFunc
+ {
+ protected int axis = -1;
+
+ public CosineSimilarity(
+ string reduction = null,
+ int axis = -1,
+ string name = null) :
+ base(reduction: reduction, name: name ?? "cosine_similarity")
+ {
+ this.axis = axis;
+ }
+
+ public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+ {
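+ // loss = -sum(l2_normalize(y_true) * l2_normalize(y_pred)) along the configured axis;
+ // more negative means more similar.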
+ Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis: this.axis);
+ Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis);
+ return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis: this.axis);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Losses/Huber.cs b/src/TensorFlowNET.Keras/Losses/Huber.cs
new file mode 100644
index 00000000..6098dee3
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Losses/Huber.cs
@@ -0,0 +1,36 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+ public class Huber : LossFunctionWrapper, ILossFunc
+ {
+ protected Tensor delta;
+
+ public Huber(
+ string reduction = null,
+ Tensor delta = null,
+ string name = null) :
+ base(reduction: reduction, name: name ?? "huber")
+ {
+ this.delta = delta ?? tf.constant(1.0);
+ }
+
+ public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+ {
+ Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT);
+ Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT);
+ Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT);
+ Tensor error = math_ops.subtract(y_pred_cast, y_true_cast);
+ Tensor abs_error = math_ops.abs(error);
+ Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype);
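+ // Piecewise Huber: 0.5 * error^2 when |error| <= delta, linear beyond that point.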
+ return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta,
+ half * math_ops.pow(error, 2),
+ half * math_ops.pow(delta, 2) + delta * (abs_error - delta)),
+ axis: -1);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Losses/ILossFunc.cs b/src/TensorFlowNET.Keras/Losses/ILossFunc.cs
index 45c39dd2..8bc226df 100644
--- a/src/TensorFlowNET.Keras/Losses/ILossFunc.cs
+++ b/src/TensorFlowNET.Keras/Losses/ILossFunc.cs
@@ -2,7 +2,8 @@
{
public interface ILossFunc
{
string Reduction { get; }
- Tensor Call(Tensor y_true, Tensor y_pred);
+ string Name { get; }
+ Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null);
}
}
diff --git a/src/TensorFlowNET.Keras/Losses/LogCosh.cs b/src/TensorFlowNET.Keras/Losses/LogCosh.cs
new file mode 100644
index 00000000..6db10bc8
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Losses/LogCosh.cs
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Operations;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+ public class LogCosh : LossFunctionWrapper, ILossFunc
+ {
+ public LogCosh(
+ string reduction = null,
+ string name = null) :
+ base(reduction: reduction, name: name == null ? "huber" : name){ }
+
+ public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+ {
+ Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
+ Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
+ Tensor x = y_pred_dispatch - y_true_cast;
+
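+ // Numerically stable identity: log(cosh(x)) = x + softplus(-2x) - log(2).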
+ return gen_math_ops.mean(x + gen_math_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.constant(2.0)), x.dtype), axis: -1);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Losses/Loss.cs b/src/TensorFlowNET.Keras/Losses/Loss.cs
index 54b2b249..fe017ac4 100644
--- a/src/TensorFlowNET.Keras/Losses/Loss.cs
+++ b/src/TensorFlowNET.Keras/Losses/Loss.cs
@@ -15,12 +15,12 @@ namespace Tensorflow.Keras.Losses
string _name_scope;
public string Reduction => reduction;
-
+ public string Name => name;
+
public Loss(string reduction = ReductionV2.AUTO,
string name = null,
bool from_logits = false)
{
- this.reduction = reduction;
+ this.reduction = reduction ?? ReductionV2.SUM_OVER_BATCH_SIZE;
this.name = name;
this.from_logits = from_logits;
_allow_sum_over_batch_size = false;
@@ -31,10 +31,10 @@ namespace Tensorflow.Keras.Losses
throw new NotImplementedException("");
}
- public Tensor Call(Tensor y_true, Tensor y_pred)
+ public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
var losses = Apply(y_true, y_pred, from_logits: from_logits);
- return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE);
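+ // Apply the configured reduction (and optional per-sample weights) to the element-wise losses.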
+ return losses_utils.compute_weighted_loss(losses, reduction: this.reduction, sample_weight: sample_weight);
}
void _set_name_scope()
diff --git a/src/TensorFlowNET.Keras/Losses/LossesApi.cs b/src/TensorFlowNET.Keras/Losses/LossesApi.cs
index 3e66b395..71cffebb 100644
--- a/src/TensorFlowNET.Keras/Losses/LossesApi.cs
+++ b/src/TensorFlowNET.Keras/Losses/LossesApi.cs
@@ -2,10 +2,31 @@
{
public class LossesApi
{
- public ILossFunc SparseCategoricalCrossentropy(bool from_logits = false)
- => new SparseCategoricalCrossentropy(from_logits: from_logits);
+ public ILossFunc SparseCategoricalCrossentropy(string reduction = null, string name = null, bool from_logits = false)
+ => new SparseCategoricalCrossentropy(reduction: reduction, name: name, from_logits: from_logits);
+
+ public ILossFunc CategoricalCrossentropy(string reduction = null, string name = null, bool from_logits = false)
+ => new CategoricalCrossentropy(reduction: reduction, name: name, from_logits: from_logits);
+
+ public ILossFunc MeanSquaredError(string reduction = null, string name = null)
+ => new MeanSquaredError(reduction: reduction, name: name);
+
+ public ILossFunc MeanSquaredLogarithmicError(string reduction = null, string name = null)
+ => new MeanSquaredLogarithmicError(reduction: reduction, name: name);
+
+ public ILossFunc MeanAbsolutePercentageError(string reduction = null, string name = null)
+ => new MeanAbsolutePercentageError(reduction: reduction, name: name);
+
+ public ILossFunc MeanAbsoluteError(string reduction = null, string name = null)
+ => new MeanAbsoluteError(reduction: reduction, name: name);
+
+ public ILossFunc CosineSimilarity(string reduction = null, string name = null, int axis = -1)
+ => new CosineSimilarity(reduction: reduction, name: name, axis: axis);
+
+ public ILossFunc Huber(string reduction = null, string name = null, Tensor delta = null)
+ => new Huber(reduction: reduction, name: name, delta: delta);
+
+ public ILossFunc LogCosh(string reduction = null, string name = null)
+ => new LogCosh(reduction: reduction, name: name);
- public ILossFunc CategoricalCrossentropy(bool from_logits = false)
- => new CategoricalCrossentropy(from_logits: from_logits);
}
}
diff --git a/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs b/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs
new file mode 100644
index 00000000..5d0f83d4
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+ public class MeanAbsoluteError : LossFunctionWrapper, ILossFunc
+ {
+ public MeanAbsoluteError(
+ string reduction = null,
+ string name = null) :
+ base(reduction: reduction, name: name == null ? "mean_absolute_error" : name){ }
+
+ public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+ {
+ Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
+ Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
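+ // MAE: mean of |y_pred - y_true| over the last axis.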
+ return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), axis: -1);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs b/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs
new file mode 100644
index 00000000..74c95b4a
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs
@@ -0,0 +1,24 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+ public class MeanAbsolutePercentageError : LossFunctionWrapper, ILossFunc
+ {
+ public MeanAbsolutePercentageError(
+ string reduction = null,
+ string name = null) :
+ base(reduction: reduction, name: name == null ? "mean_absolute_percentage_error" : name){ }
+
+ public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+ {
+ Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
+ Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
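+ // Keras-style MAPE: clamp |y_true| away from zero with 1e-7 to avoid division by zero.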
+ Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype));
+ return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, axis: -1);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs
new file mode 100644
index 00000000..24ef1043
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+ public class MeanSquaredError : LossFunctionWrapper, ILossFunc
+ {
+ public MeanSquaredError(
+ string reduction = null,
+ string name = null) :
+ base(reduction: reduction, name: name ?? "mean_squared_error") { }
+
+ public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+ {
+ Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
+ Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
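+ // MSE: mean of squared differences over the last axis.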
+ return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs b/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs
new file mode 100644
index 00000000..22b5a6ff
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs
@@ -0,0 +1,33 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+ public class MeanSquaredLogarithmicError : LossFunctionWrapper, ILossFunc
+ {
+ public MeanSquaredLogarithmicError(
+ string reduction = null,
+ string name = null) :
+ base(reduction: reduction, name: name == null ? "mean_squared_logarithmic_error" : name){ }
+
+
+ public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+ {
+ Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
+ Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
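+ // Clamp inputs to at least 1e-7 before the log transform, using dtype-matched constants.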
+ Tensor first_log = null, second_log = null;
+ if (y_pred_dispatch.dtype == TF_DataType.TF_DOUBLE) {
+ first_log = math_ops.log(gen_math_ops.maximum(y_pred_dispatch, 1e-7) + 1.0);
+ second_log = math_ops.log(gen_math_ops.maximum(y_true_cast, 1e-7) + 1.0);
+ }
+ else {
+ first_log = math_ops.log(gen_math_ops.maximum(y_pred_dispatch, 1e-7f) + 1.0f);
+ second_log = math_ops.log(gen_math_ops.maximum(y_true_cast, 1e-7f) + 1.0f);
+ }
+ return gen_math_ops.mean(gen_math_ops.squared_difference(first_log, second_log), axis: -1);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Keras/Losses/ReductionV2.cs b/src/TensorFlowNET.Keras/Losses/ReductionV2.cs
index afe2006d..4b6cbbfd 100644
--- a/src/TensorFlowNET.Keras/Losses/ReductionV2.cs
+++ b/src/TensorFlowNET.Keras/Losses/ReductionV2.cs
@@ -4,6 +4,7 @@
{
public const string NONE = "none";
public const string AUTO = "auto";
+ public const string SUM = "sum";
public const string SUM_OVER_BATCH_SIZE = "sum_over_batch_size";
public const string WEIGHTED_MEAN = "weighted_mean";
}
diff --git a/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs
index fe14e887..2cf24fc3 100644
--- a/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs
+++ b/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs
@@ -4,14 +4,11 @@ namespace Tensorflow.Keras.Losses
{
public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc
{
- public SparseCategoricalCrossentropy(bool from_logits = false,
- string reduction = ReductionV2.AUTO,
- string name = "sparse_categorical_crossentropy") :
- base(reduction: reduction,
- name: name)
- {
-
- }
+ public SparseCategoricalCrossentropy(
+ bool from_logits = false,
+ string reduction = null,
+ string name = null) :
+ base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name){ }
public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
{
diff --git a/test/TensorFlowNET.UnitTest/Keras/CosineSimilarity.Test.cs b/test/TensorFlowNET.UnitTest/Keras/CosineSimilarity.Test.cs
new file mode 100644
index 00000000..70e07264
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/CosineSimilarity.Test.cs
@@ -0,0 +1,76 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+ [TestClass]
+ public class CosineSimilarity
+ {
+ //https://keras.io/api/losses/regression_losses/
+
+ NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 1.0f, 1.0f } };
+ NDArray y_pred_float = new float[,] { { 1.0f, 0.0f }, { 1.0f, 1.0f } };
+
+ [TestMethod]
+
+ public void _Default()
+ {
+ //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+ //>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis = 1)
+ //>>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]
+ //>>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]
+ //>>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]
+ //>>> # loss = -mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
+ //>>> # = -((0. + 0.) + (0.5 + 0.5)) / 2
+ //>>> cosine_loss(y_true, y_pred).numpy()
+ //-0.5
+ var loss = keras.losses.CosineSimilarity(axis : 1);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)(-0.49999997f), call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _Sample_Weight()
+ {
+ //>>> # Calling with 'sample_weight'.
+ //>>> cosine_loss(y_true, y_pred, sample_weight =[0.8, 0.2]).numpy()
+ //-0.0999
+ var loss = keras.losses.CosineSimilarity();
+ var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f });
+ Assert.AreEqual((NDArray)(-0.099999994f), call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _SUM()
+ {
+ //>>> # Using 'sum' reduction type.
+ //>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis = 1,
+ //... reduction = tf.keras.losses.Reduction.SUM)
+ //>>> cosine_loss(y_true, y_pred).numpy()
+ //-0.999
+ var loss = keras.losses.CosineSimilarity(axis: 1, reduction: ReductionV2.SUM);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)(-0.99999994f), call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _None()
+ {
+ //>>> # Using 'none' reduction type.
+ //>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis = 1,
+ //... reduction = tf.keras.losses.Reduction.NONE)
+ //>>> cosine_loss(y_true, y_pred).numpy()
+ //array([-0., -0.999], dtype = float32)
+ var loss = keras.losses.CosineSimilarity(axis: 1, reduction: ReductionV2.NONE);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)new float[] { -0f, -0.99999994f }, call.numpy());
+ }
+
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/Keras/Huber.Test.cs b/test/TensorFlowNET.UnitTest/Keras/Huber.Test.cs
new file mode 100644
index 00000000..cbc16eab
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/Huber.Test.cs
@@ -0,0 +1,72 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+ [TestClass]
+ public class Huber
+ {
+ //https://keras.io/api/losses/regression_losses/#huber-class
+
+ NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+ NDArray y_pred_float = new float[,] { { 0.6f, 0.4f }, { 0.4f, 0.6f } };
+
+ [TestMethod]
+
+ public void _Default()
+ {
+ //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+ //>>> h = tf.keras.losses.Huber()
+ //>>> h(y_true, y_pred).numpy()
+ //0.155
+ var loss = keras.losses.Huber();
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)0.155f, call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _Sample_Weight()
+ {
+ //>>> # Calling with 'sample_weight'.
+ //>>> h(y_true, y_pred, sample_weight =[0.1, 0.0]).numpy()
+ //0.009
+ var loss = keras.losses.Huber();
+ var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.1f, 0.0f });
+ Assert.AreEqual((NDArray)0.009000001f, call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _SUM()
+ {
+ //>>> # Using 'sum' reduction type.
+ //>>> h = tf.keras.losses.Huber(
+ //... reduction = tf.keras.losses.Reduction.SUM)
+ //>>> h(y_true, y_pred).numpy()
+ //0.31
+ var loss = keras.losses.Huber(reduction: ReductionV2.SUM);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)0.31f, call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _None()
+ {
+ //>>> # Using 'none' reduction type.
+ //>>> h = tf.keras.losses.Huber(
+ //... reduction = tf.keras.losses.Reduction.NONE)
+ //>>> h(y_true, y_pred).numpy()
+ //array([0.18, 0.13], dtype = float32)
+ var loss = keras.losses.Huber(reduction: ReductionV2.NONE);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)new float[] { 0.18f, 0.13000001f }, call.numpy());
+ }
+
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/Keras/LogCosh.Test.cs b/test/TensorFlowNET.UnitTest/Keras/LogCosh.Test.cs
new file mode 100644
index 00000000..48d2d862
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/LogCosh.Test.cs
@@ -0,0 +1,72 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+ [TestClass]
+ public class LogCosh
+ {
+ //https://keras.io/api/losses/regression_losses/#logcosh-class
+
+ NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+ NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 0.0f, 0.0f } };
+
+ [TestMethod]
+
+ public void _Default()
+ {
+ //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+ //>>> l = tf.keras.losses.LogCosh()
+ //>>> l(y_true, y_pred).numpy()
+ //0.108
+ var loss = keras.losses.LogCosh();
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)0.1084452f, call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _Sample_Weight()
+ {
+ //>>> # Calling with 'sample_weight'.
+ //>>> l(y_true, y_pred, sample_weight =[0.8, 0.2]).numpy()
+ //0.087
+ var loss = keras.losses.LogCosh();
+ var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f });
+ Assert.AreEqual((NDArray)0.08675616f, call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _SUM()
+ {
+ //>>> # Using 'sum' reduction type.
+ //>>> l = tf.keras.losses.LogCosh(
+ //... reduction = tf.keras.losses.Reduction.SUM)
+ //>>> l(y_true, y_pred).numpy()
+ //0.217
+ var loss = keras.losses.LogCosh(reduction: ReductionV2.SUM);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)0.2168904f, call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _None()
+ {
+ //>>> # Using 'none' reduction type.
+ //>>> l = tf.keras.losses.LogCosh(
+ //... reduction = tf.keras.losses.Reduction.NONE)
+ //>>> l(y_true, y_pred).numpy()
+ //array([0.217, 0.], dtype = float32)
+ var loss = keras.losses.LogCosh(reduction: ReductionV2.NONE);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)new float[] { 0.2168904f, 0.0f }, call.numpy());
+ }
+
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/Keras/MeanAbsoluteError.Test.cs b/test/TensorFlowNET.UnitTest/Keras/MeanAbsoluteError.Test.cs
new file mode 100644
index 00000000..2b7a2504
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/MeanAbsoluteError.Test.cs
@@ -0,0 +1,73 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+ [TestClass]
+ public class MeanAbsoluteError
+ {
+ //https://keras.io/api/losses/regression_losses/
+
+ NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+ NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
+
+ [TestMethod]
+
+ public void _Default()
+ {
+
+ //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+ //>>> mae = tf.keras.losses.MeanAbsoluteError()
+ //>>> mae(y_true, y_pred).numpy()
+ //0.5
+ var loss = keras.losses.MeanAbsoluteError();
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)(0.5f), call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _Sample_Weight()
+ {
+ //>>> # Calling with 'sample_weight'.
+ //>>> mae(y_true, y_pred, sample_weight =[0.7, 0.3]).numpy()
+ //0.25
+ var loss = keras.losses.MeanAbsoluteError();
+ var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.7f, 0.3f });
+ Assert.AreEqual((NDArray)(0.25f), call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _SUM()
+ {
+ //>>> # Using 'sum' reduction type.
+ //>>> mae = tf.keras.losses.MeanAbsoluteError(
+ //... reduction = tf.keras.losses.Reduction.SUM)
+ //>>> mae(y_true, y_pred).numpy()
+ //1.0
+ var loss = keras.losses.MeanAbsoluteError(reduction: ReductionV2.SUM);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)(1.0f), call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _None()
+ {
+ //>>> # Using 'none' reduction type.
+ //>>> mae = tf.keras.losses.MeanAbsoluteError(
+ //... reduction = tf.keras.losses.Reduction.NONE)
+ //>>> mae(y_true, y_pred).numpy()
+ //array([0.5, 0.5], dtype = float32)
+ var loss = keras.losses.MeanAbsoluteError(reduction: ReductionV2.NONE);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)new float[] { 0.5f, 0.5f }, call.numpy());
+ }
+
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/Keras/MeanAbsolutePercentageError.Test.cs b/test/TensorFlowNET.UnitTest/Keras/MeanAbsolutePercentageError.Test.cs
new file mode 100644
index 00000000..97b43503
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/MeanAbsolutePercentageError.Test.cs
@@ -0,0 +1,72 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+ [TestClass]
+ public class MeanAbsolutePercentageError
+ {
+ //https://keras.io/api/losses/regression_losses/
+
+ NDArray y_true_float = new float[,] { { 2.0f, 1.0f }, { 2.0f, 3.0f } };
+ NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
+
+ [TestMethod]
+
+ public void _Default()
+ {
+ //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+ //>>> mape = tf.keras.losses.MeanAbsolutePercentageError()
+ //>>> mape(y_true, y_pred).numpy()
+ //50.
+ var loss = keras.losses.MeanAbsolutePercentageError();
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)(50f), call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _Sample_Weight()
+ {
+ //>>> # Calling with 'sample_weight'.
+ //>>> mape(y_true, y_pred, sample_weight =[0.7, 0.3]).numpy()
+ //20.
+ var loss = keras.losses.MeanAbsolutePercentageError();
+ var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.7f, 0.3f });
+ Assert.AreEqual((NDArray)(20f), call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _SUM()
+ {
+ //>>> # Using 'sum' reduction type.
+ //>>> mape = tf.keras.losses.MeanAbsolutePercentageError(
+ //... reduction = tf.keras.losses.Reduction.SUM)
+ //>>> mape(y_true, y_pred).numpy()
+ //100.
+ var loss = keras.losses.MeanAbsolutePercentageError(reduction: ReductionV2.SUM);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)(100f), call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _None()
+ {
+ //>>> # Using 'none' reduction type.
+ //>>> mape = tf.keras.losses.MeanAbsolutePercentageError(
+ //... reduction = tf.keras.losses.Reduction.NONE)
+ //>>> mape(y_true, y_pred).numpy()
+ //array([25., 75.], dtype = float32)
+ var loss = keras.losses.MeanAbsolutePercentageError(reduction: ReductionV2.NONE);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)new float[] { 25f, 75f }, call.numpy());
+ }
+
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/Keras/MeanSquaredError.Test.cs b/test/TensorFlowNET.UnitTest/Keras/MeanSquaredError.Test.cs
new file mode 100644
index 00000000..f1c782f8
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/MeanSquaredError.Test.cs
@@ -0,0 +1,65 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+ [TestClass]
+ public class MeanSquaredErrorTest
+ {
+ //https://keras.io/api/losses/regression_losses/#meansquarederror-class
+
+ private NDArray y_true = new double[,] { { 0.0, 1.0 }, { 0.0, 0.0 } };
+ private NDArray y_pred = new double[,] { { 1.0, 1.0 }, { 1.0, 0.0 } };
+
+ [TestMethod]
+
+ public void Mse_Double()
+ {
+ var mse = keras.losses.MeanSquaredError();
+ var call = mse.Call(y_true, y_pred);
+ Assert.AreEqual((NDArray)0.5, call.numpy());
+ }
+
+ [TestMethod]
+
+ public void Mse_Float()
+ {
+ NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+ NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
+
+ var mse = keras.losses.MeanSquaredError();
+ var call = mse.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)0.5f, call.numpy());
+ }
+
+ [TestMethod]
+
+ public void Mse_Sample_Weight()
+ {
+ var mse = keras.losses.MeanSquaredError();
+ var call = mse.Call(y_true, y_pred, sample_weight: (NDArray)new double[] { 0.7, 0.3 });
+ Assert.AreEqual((NDArray)0.25, call.numpy());
+ }
+
+ [TestMethod]
+ public void Mse_Reduction_SUM()
+ {
+ var mse = keras.losses.MeanSquaredError(reduction: ReductionV2.SUM);
+ var call = mse.Call(y_true, y_pred);
+ Assert.AreEqual((NDArray)1.0, call.numpy());
+ }
+
+ [TestMethod]
+
+ public void Mse_Reduction_NONE()
+ {
+ var mse = keras.losses.MeanSquaredError(reduction: ReductionV2.NONE);
+ var call = mse.Call(y_true, y_pred);
+ Assert.AreEqual((NDArray)new double[] { 0.5, 0.5 }, call.numpy());
+ }
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/Keras/MeanSquaredLogarithmicError.Test.cs b/test/TensorFlowNET.UnitTest/Keras/MeanSquaredLogarithmicError.Test.cs
new file mode 100644
index 00000000..28499143
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/MeanSquaredLogarithmicError.Test.cs
@@ -0,0 +1,72 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+ [TestClass]
+ public class MeanSquaredLogarithmicError
+ {
+ //https://keras.io/api/losses/regression_losses/
+
+ NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+ NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
+
+ [TestMethod]
+
+ public void _Default()
+ {
+ //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+ //>>> msle = tf.keras.losses.MeanSquaredLogarithmicError()
+ //>>> msle(y_true, y_pred).numpy()
+ //0.240
+ var loss = keras.losses.MeanSquaredLogarithmicError();
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)(0.24022643f), call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _Sample_Weight()
+ {
+ //>>> # Calling with 'sample_weight'.
+ //>>> msle(y_true, y_pred, sample_weight =[0.7, 0.3]).numpy()
+ //0.120
+ var loss = keras.losses.MeanSquaredLogarithmicError();
+ var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.7f, 0.3f });
+ Assert.AreEqual((NDArray)(0.12011322f), call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _SUM()
+ {
+ //>>> # Using 'sum' reduction type.
+ //>>> msle = tf.keras.losses.MeanSquaredLogarithmicError(
+ //... reduction = tf.keras.losses.Reduction.SUM)
+ //>>> msle(y_true, y_pred).numpy()
+ //0.480
+ var loss = keras.losses.MeanSquaredLogarithmicError(reduction: ReductionV2.SUM);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)(0.48045287f), call.numpy());
+ }
+
+ [TestMethod]
+
+ public void _None()
+ {
+ //>>> # Using 'none' reduction type.
+ //>>> msle = tf.keras.losses.MeanSquaredLogarithmicError(
+ //... reduction = tf.keras.losses.Reduction.NONE)
+ //>>> msle(y_true, y_pred).numpy()
+ //array([0.240, 0.240], dtype = float32)
+ var loss = keras.losses.MeanSquaredLogarithmicError(reduction: ReductionV2.NONE);
+ var call = loss.Call(y_true_float, y_pred_float);
+ Assert.AreEqual((NDArray)new float[] { 0.24022643f, 0.24022643f }, call.numpy());
+ }
+
+ }
+}