diff --git a/README.md b/README.md
index 5ee82c93..2fd46c6b 100644
--- a/README.md
+++ b/README.md
@@ -56,30 +56,32 @@ PM> Install-Package SciSharp.TensorFlow.Redist-Windows-GPU
 
 Import TF.NET and Keras API in your project.
 
-```cs
+```csharp
 using static Tensorflow.Binding;
 using static Tensorflow.KerasApi;
+using Tensorflow;
+using NumSharp;
 ```
 
 Linear Regression in `Eager` mode:
 
-```c#
+```csharp
 // Parameters
 var training_steps = 1000;
 var learning_rate = 0.01f;
 var display_step = 100;
 
 // Sample data
-var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
+var X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
              7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f);
-var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
+var Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
              2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);
-var n_samples = train_X.shape[0];
+var n_samples = X.shape[0];
 
 // We can set a fixed init value in order to demo
 var W = tf.Variable(-0.06f, name: "weight");
 var b = tf.Variable(-0.73f, name: "bias");
-var optimizer = tf.optimizers.SGD(learning_rate);
+var optimizer = keras.optimizers.SGD(learning_rate);
 
 // Run training for the given number of steps.
 foreach (var step in range(1, training_steps + 1))
@@ -112,46 +114,40 @@ Run this example in [Jupyter Notebook](https://github.com/SciSharp/SciSharpCube)
 
 Toy version of `ResNet` in `Keras` functional API:
 
 ```csharp
+var layers = new LayersApi();
 // input layer
 var inputs = keras.Input(shape: (32, 32, 3), name: "img");
-
 // convolutional layer
 var x = layers.Conv2D(32, 3, activation: "relu").Apply(inputs);
 x = layers.Conv2D(64, 3, activation: "relu").Apply(x);
 var block_1_output = layers.MaxPooling2D(3).Apply(x);
-
 x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_1_output);
 x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x);
-var block_2_output = layers.add(x, block_1_output);
-
+var block_2_output = layers.Add().Apply(new Tensors(x, block_1_output));
 x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(block_2_output);
 x = layers.Conv2D(64, 3, activation: "relu", padding: "same").Apply(x);
-var block_3_output = layers.add(x, block_2_output);
-
+var block_3_output = layers.Add().Apply(new Tensors(x, block_2_output));
 x = layers.Conv2D(64, 3, activation: "relu").Apply(block_3_output);
 x = layers.GlobalAveragePooling2D().Apply(x);
 x = layers.Dense(256, activation: "relu").Apply(x);
 x = layers.Dropout(0.5f).Apply(x);
-
 // output layer
 var outputs = layers.Dense(10).Apply(x);
-
 // build keras model
-model = keras.Model(inputs, outputs, name: "toy_resnet");
+var model = keras.Model(inputs, outputs, name: "toy_resnet");
 model.summary();
-
 // compile keras model in tensorflow static graph
 model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
-    loss: keras.losses.CategoricalCrossentropy(from_logits: true),
-    metrics: new[] { "acc" });
-
+              loss: keras.losses.CategoricalCrossentropy(from_logits: true),
+              metrics: new[] { "acc" });
 // prepare dataset
 var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
-
+x_train = x_train / 255.0f;
+y_train = np_utils.to_categorical(y_train, 10);
 // training
-model.fit(x_train[new Slice(0, 1000)], y_train[new Slice(0, 1000)],
-    batch_size: 64,
-    epochs: 10,
+model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)],
+          batch_size: 64,
+          epochs: 10,
           validation_split: 0.2f);
 ```
 
@@ -260,4 +256,4 @@ WeChat Sponsor 微信打赏:
 
 TensorFlow.NET is a part of [SciSharp STACK](https://scisharp.github.io/SciSharp/)
-
+
\ No newline at end of file
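For context on the two preprocessing lines added above: `CategoricalCrossentropy` expects one-hot targets, and `np_utils.to_categorical` converts the integer CIFAR-10 labels into that form. A minimal sketch of what it is expected to produce (illustrative values, not part of the README):

```csharp
// 10-class one-hot encoding:
//   3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
//   0 -> [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
var labels = np.array(new[] { 3, 0 });
var one_hot = np_utils.to_categorical(labels, 10);   // shape (2, 10)
```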
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs
index b0d741f1..5c335b81 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.cs
@@ -506,6 +506,27 @@ namespace Tensorflow
             }
         }
 
+        public static Tensor where_v2(Tensor condition, object x = null, object y = null, string name = null)
+        {
+            if (x == null && y == null)
+            {
+                return tf_with(ops.name_scope(name, "Where", new { condition }), scope =>
+                {
+                    name = scope;
+                    condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition");
+                    return gen_array_ops.where(condition: condition, name: name);
+                });
+            }
+            else if (x != null && y != null)
+            {
+                return gen_array_ops.select_v2(condition, x, y, name);
+            }
+            else
+            {
+                throw new ValueError("x and y must both be non-None or both be None.");
+            }
+        }
+
         /// <summary>
         /// Returns the shape of a tensor.
         /// </summary>
diff --git a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
index 50dde5c3..019d19be 100644
--- a/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_array_ops.cs
@@ -423,6 +423,21 @@ namespace Tensorflow
             var _op = tf.OpDefLib._apply_op_helper("Select", name, new { condition, t = x, e = y });
             return _op.outputs[0];
         }
+
+        public static Tensor select_v2<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
+        {
+            if (tf.Context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "SelectV2", name,
+                    null,
+                    condition, x, y);
+
+                return results[0];
+            }
+
+            var _op = tf.OpDefLib._apply_op_helper("SelectV2", name, new { condition, t = x, e = y });
+            return _op.outputs[0];
+        }
 
         public static Tensor scatter_nd(Tensor indices, Tensor updates, Tensor[] shape, string name = null)
         {
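A quick sketch of the two dispatch paths `where_v2` supports, mirroring the Python `tf.where` contract (hypothetical tensors, eager mode):

```csharp
var cond = tf.constant(new[] { true, false, true });
var x = tf.constant(new[] { 1f, 2f, 3f });
var y = tf.constant(new[] { 10f, 20f, 30f });

// Both branches supplied: element-wise select via SelectV2 -> [1, 20, 3]
var merged = array_ops.where_v2(cond, x, y);

// No branches supplied: coordinates of true elements via Where -> [[0], [2]]
var indices = array_ops.where_v2(cond);
```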
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index 67b5b614..ca74ea5f 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -714,7 +714,23 @@ namespace Tensorflow
             return _op.outputs[0];
         }
 
+        public static Tensor softplus(Tensor features, string name = null)
+        {
+            if (tf.Context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "Softplus", name,
+                    null,
+                    features);
+
+                return results[0];
+            }
+
+            var _op = tf.OpDefLib._apply_op_helper("Softplus", name, args: new { features });
+
+            return _op.outputs[0];
+        }
+
         public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate = false, string name = null)
             => tf.Context.RunInAutoMode(()
                 => tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }).output, ()
@@ -1068,6 +1084,15 @@ namespace Tensorflow
 
         public static Tensor _abs(Tensor x, string name = null)
         {
+            if (tf.Context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "Abs", name,
+                    null,
+                    x);
+
+                return results[0];
+            }
             var _op = tf.OpDefLib._apply_op_helper("Abs", name, args: new { x });
 
             return _op.output;
@@ -1202,6 +1227,15 @@ namespace Tensorflow
         /// </summary>
         public static Tensor rsqrt(Tensor x, string name = null)
         {
+            if (tf.Context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "Rsqrt", name,
+                    null,
+                    x);
+
+                return results[0];
+            }
             var _op = tf.OpDefLib._apply_op_helper("Rsqrt", name, new { x });
 
             return _op.outputs[0];
diff --git a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
index 121a4728..7b008e4e 100644
--- a/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
+++ b/src/TensorFlowNET.Core/Operations/nn_impl.py.cs
@@ -31,7 +31,7 @@ namespace Tensorflow
         /// </summary>
         public static Tensor l2_normalize(Tensor x,
             int axis = 0,
-            float epsilon = 1e-12f,
+            Tensor epsilon = null,
             string name = null)
         {
             return tf_with(ops.name_scope(name, "l2_normalize", new { x }), scope =>
@@ -39,7 +39,7 @@ namespace Tensorflow
                 x = ops.convert_to_tensor(x, name: "x");
                 var sq = math_ops.square(x);
                 var square_sum = math_ops.reduce_sum(sq, axis, keepdims: true);
-                var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon));
+                var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon == null ? tf.constant(1e-12f) : epsilon));
                 return math_ops.multiply(x, x_inv_norm, name: name);
             });
         }
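`l2_normalize` computes `x * rsqrt(max(sum(x^2), epsilon))`, so the epsilon only matters for near-zero norms. A small sanity check on the new `Tensor`-typed parameter (illustrative; default `axis: 0`):

```csharp
var v = tf.constant(new[] { 3f, 4f });
var unit = nn_impl.l2_normalize(v);                               // ≈ [0.6, 0.8] (3-4-5 triangle)
var same = nn_impl.l2_normalize(v, epsilon: tf.constant(1e-6f));  // unchanged: norm² = 25 >> ε
```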
"cosine_similarity" : name) + { + this.axis = axis; + } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis : this.axis); + Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis); + return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis : this.axis); + } + } +} diff --git a/src/TensorFlowNET.Keras/Losses/Huber.cs b/src/TensorFlowNET.Keras/Losses/Huber.cs new file mode 100644 index 00000000..6098dee3 --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/Huber.cs @@ -0,0 +1,36 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Losses +{ + public class Huber : LossFunctionWrapper, ILossFunc + { + protected Tensor delta = tf.Variable(1.0) ; + public Huber ( + string reduction = null, + Tensor delta = null, + string name = null) : + base(reduction: reduction, name: name == null ? "huber" : name) + { + this.delta = delta==null? this.delta: delta; + + } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT); + Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT); + Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT); + Tensor error = math_ops.subtract(y_pred_cast, y_true_cast); + Tensor abs_error = math_ops.abs(error); + Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype); + return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta, + half * math_ops.pow(error, 2), + half * math_ops.pow(delta, 2) + delta * (abs_error - delta)), + axis : -1); + } + } +} diff --git a/src/TensorFlowNET.Keras/Losses/ILossFunc.cs b/src/TensorFlowNET.Keras/Losses/ILossFunc.cs index 45c39dd2..8bc226df 100644 --- a/src/TensorFlowNET.Keras/Losses/ILossFunc.cs +++ b/src/TensorFlowNET.Keras/Losses/ILossFunc.cs @@ -2,7 +2,8 @@ { public interface ILossFunc { - string Reduction { get; } - Tensor Call(Tensor y_true, Tensor y_pred); + public string Reduction { get; } + public string Name { get; } + Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null); } } diff --git a/src/TensorFlowNET.Keras/Losses/LogCosh.cs b/src/TensorFlowNET.Keras/Losses/LogCosh.cs new file mode 100644 index 00000000..6db10bc8 --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/LogCosh.cs @@ -0,0 +1,28 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Operations; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Losses +{ + public class LogCosh : LossFunctionWrapper, ILossFunc + { + public LogCosh( + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? 
"huber" : name){ } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + Tensor x = y_pred_dispatch - y_true_cast; + + return gen_math_ops.mean(x + gen_math_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.Variable(2.0)), x.dtype),axis: -1); + + + } + } +} diff --git a/src/TensorFlowNET.Keras/Losses/Loss.cs b/src/TensorFlowNET.Keras/Losses/Loss.cs index 54b2b249..fe017ac4 100644 --- a/src/TensorFlowNET.Keras/Losses/Loss.cs +++ b/src/TensorFlowNET.Keras/Losses/Loss.cs @@ -15,12 +15,12 @@ namespace Tensorflow.Keras.Losses string _name_scope; public string Reduction => reduction; - + public string Name => name; public Loss(string reduction = ReductionV2.AUTO, string name = null, bool from_logits = false) { - this.reduction = reduction; + this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction; this.name = name; this.from_logits = from_logits; _allow_sum_over_batch_size = false; @@ -31,10 +31,10 @@ namespace Tensorflow.Keras.Losses throw new NotImplementedException(""); } - public Tensor Call(Tensor y_true, Tensor y_pred) + public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) { var losses = Apply(y_true, y_pred, from_logits: from_logits); - return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE); + return losses_utils.compute_weighted_loss(losses, reduction: this.reduction , sample_weight: sample_weight); } void _set_name_scope() diff --git a/src/TensorFlowNET.Keras/Losses/LossesApi.cs b/src/TensorFlowNET.Keras/Losses/LossesApi.cs index 3e66b395..71cffebb 100644 --- a/src/TensorFlowNET.Keras/Losses/LossesApi.cs +++ b/src/TensorFlowNET.Keras/Losses/LossesApi.cs @@ -2,10 +2,31 @@ { public class LossesApi { - public ILossFunc SparseCategoricalCrossentropy(bool from_logits = false) - => new SparseCategoricalCrossentropy(from_logits: from_logits); + public ILossFunc SparseCategoricalCrossentropy(string reduction = null, string name = null,bool from_logits = false) + => new SparseCategoricalCrossentropy(reduction: reduction, name: name,from_logits: from_logits); + + public ILossFunc CategoricalCrossentropy(string reduction = null, string name = null,bool from_logits = false) + => new CategoricalCrossentropy(reduction: reduction, name: name,from_logits: from_logits); + + public ILossFunc MeanSquaredError(string reduction = null, string name = null) + => new MeanSquaredError(reduction: reduction, name:name); + public ILossFunc MeanSquaredLogarithmicError(string reduction = null, string name = null) + => new MeanSquaredLogarithmicError(reduction: reduction, name: name); + + public ILossFunc MeanAbsolutePercentageError(string reduction = null, string name = null) + => new MeanAbsolutePercentageError(reduction: reduction, name: name); + + public ILossFunc MeanAbsoluteError(string reduction = null, string name = null) + => new MeanAbsoluteError(reduction: reduction, name: name); + + public ILossFunc CosineSimilarity(string reduction = null, string name = null,int axis=-1) + => new CosineSimilarity(reduction: reduction, name: name, axis: axis); + + public ILossFunc Huber(string reduction = null, string name = null, Tensor delta=null) + => new Huber(reduction: reduction, name: name, delta: delta); + + public ILossFunc LogCosh(string reduction = null, string name = null) + => new LogCosh(reduction: reduction, name: 
diff --git a/src/TensorFlowNET.Keras/Losses/Loss.cs b/src/TensorFlowNET.Keras/Losses/Loss.cs
index 54b2b249..fe017ac4 100644
--- a/src/TensorFlowNET.Keras/Losses/Loss.cs
+++ b/src/TensorFlowNET.Keras/Losses/Loss.cs
@@ -15,12 +15,12 @@ namespace Tensorflow.Keras.Losses
         string _name_scope;
 
         public string Reduction => reduction;
-
+        public string Name => name;
 
         public Loss(string reduction = ReductionV2.AUTO,
             string name = null,
             bool from_logits = false)
         {
-            this.reduction = reduction;
+            this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction;
             this.name = name;
             this.from_logits = from_logits;
             _allow_sum_over_batch_size = false;
@@ -31,10 +31,10 @@ namespace Tensorflow.Keras.Losses
             throw new NotImplementedException("");
         }
 
-        public Tensor Call(Tensor y_true, Tensor y_pred)
+        public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
         {
             var losses = Apply(y_true, y_pred, from_logits: from_logits);
-            return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE);
+            return losses_utils.compute_weighted_loss(losses, reduction: this.reduction, sample_weight: sample_weight);
         }
 
         void _set_name_scope()
diff --git a/src/TensorFlowNET.Keras/Losses/LossesApi.cs b/src/TensorFlowNET.Keras/Losses/LossesApi.cs
index 3e66b395..71cffebb 100644
--- a/src/TensorFlowNET.Keras/Losses/LossesApi.cs
+++ b/src/TensorFlowNET.Keras/Losses/LossesApi.cs
@@ -2,10 +2,31 @@
 {
     public class LossesApi
     {
-        public ILossFunc SparseCategoricalCrossentropy(bool from_logits = false)
-            => new SparseCategoricalCrossentropy(from_logits: from_logits);
+        public ILossFunc SparseCategoricalCrossentropy(string reduction = null, string name = null, bool from_logits = false)
+            => new SparseCategoricalCrossentropy(reduction: reduction, name: name, from_logits: from_logits);
+
+        public ILossFunc CategoricalCrossentropy(string reduction = null, string name = null, bool from_logits = false)
+            => new CategoricalCrossentropy(reduction: reduction, name: name, from_logits: from_logits);
+
+        public ILossFunc MeanSquaredError(string reduction = null, string name = null)
+            => new MeanSquaredError(reduction: reduction, name: name);
+
+        public ILossFunc MeanSquaredLogarithmicError(string reduction = null, string name = null)
+            => new MeanSquaredLogarithmicError(reduction: reduction, name: name);
+
+        public ILossFunc MeanAbsolutePercentageError(string reduction = null, string name = null)
+            => new MeanAbsolutePercentageError(reduction: reduction, name: name);
+
+        public ILossFunc MeanAbsoluteError(string reduction = null, string name = null)
+            => new MeanAbsoluteError(reduction: reduction, name: name);
+
+        public ILossFunc CosineSimilarity(string reduction = null, string name = null, int axis = -1)
+            => new CosineSimilarity(reduction: reduction, name: name, axis: axis);
+
+        public ILossFunc Huber(string reduction = null, string name = null, Tensor delta = null)
+            => new Huber(reduction: reduction, name: name, delta: delta);
+
+        public ILossFunc LogCosh(string reduction = null, string name = null)
+            => new LogCosh(reduction: reduction, name: name);
-
-        public ILossFunc CategoricalCrossentropy(bool from_logits = false)
-            => new CategoricalCrossentropy(from_logits: from_logits);
     }
 }
diff --git a/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs b/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs
new file mode 100644
index 00000000..5d0f83d4
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Losses/MeanAbsoluteError.cs
@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+    public class MeanAbsoluteError : LossFunctionWrapper, ILossFunc
+    {
+        public MeanAbsoluteError(
+            string reduction = null,
+            string name = null) :
+            base(reduction: reduction, name: name == null ? "mean_absolute_error" : name) { }
+
+        public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+        {
+            Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
+            Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
+            return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), axis: -1);
+        }
+    }
+}
diff --git a/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs b/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs
new file mode 100644
index 00000000..74c95b4a
--- /dev/null
+++ b/src/TensorFlowNET.Keras/Losses/MeanAbsolutePercentageError.cs
@@ -0,0 +1,24 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+    public class MeanAbsolutePercentageError : LossFunctionWrapper, ILossFunc
+    {
+        public MeanAbsolutePercentageError(
+            string reduction = null,
+            string name = null) :
+            base(reduction: reduction, name: name == null ? "mean_absolute_percentage_error" : name) { }
+
+        public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+        {
+            Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
+            Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
+            Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype));
+            return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, axis: -1);
+        }
+    }
+}
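For reference, the `Apply` above computes `100 * mean(|y_true - y_pred| / max(|y_true|, 1e-7), axis: -1)`; with the values used by its unit test later in this diff:

```latex
\frac{|y - \hat{y}|}{|y|} =
\begin{bmatrix} 1/2 & 0 \\ 1/2 & 1 \end{bmatrix}
\;\Rightarrow\; \text{per-sample means } [0.25,\ 0.75]
\;\Rightarrow\; 100 \times \text{their mean} = 50.
```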
"mean_squared_error" : name){ } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1); + } + } +} diff --git a/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs b/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs new file mode 100644 index 00000000..22b5a6ff --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/MeanSquaredLogarithmicError.cs @@ -0,0 +1,33 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Losses +{ + public class MeanSquaredLogarithmicError : LossFunctionWrapper, ILossFunc + { + public MeanSquaredLogarithmicError( + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? "mean_squared_logarithmic_error" : name){ } + + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + Tensor first_log=null, second_log=null; + if (y_pred_dispatch.dtype == TF_DataType.TF_DOUBLE) { + first_log = math_ops.log(gen_math_ops.maximum(y_pred_dispatch, 1e-7) + 1.0); + second_log = math_ops.log(gen_math_ops.maximum(y_true_cast, 1e-7) + 1.0); + } + else { + first_log = math_ops.log(gen_math_ops.maximum(y_pred_dispatch, 1e-7f) + 1.0f); + second_log = math_ops.log(gen_math_ops.maximum(y_true_cast, 1e-7f) + 1.0f); + } + return gen_math_ops.mean(gen_math_ops.squared_difference(first_log, second_log), axis: -1); + } + } +} diff --git a/src/TensorFlowNET.Keras/Losses/ReductionV2.cs b/src/TensorFlowNET.Keras/Losses/ReductionV2.cs index afe2006d..4b6cbbfd 100644 --- a/src/TensorFlowNET.Keras/Losses/ReductionV2.cs +++ b/src/TensorFlowNET.Keras/Losses/ReductionV2.cs @@ -4,6 +4,7 @@ { public const string NONE = "none"; public const string AUTO = "auto"; + public const string SUM = "sum"; public const string SUM_OVER_BATCH_SIZE = "sum_over_batch_size"; public const string WEIGHTED_MEAN = "weighted_mean"; } diff --git a/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs index fe14e887..2cf24fc3 100644 --- a/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs +++ b/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs @@ -4,14 +4,11 @@ namespace Tensorflow.Keras.Losses { public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc { - public SparseCategoricalCrossentropy(bool from_logits = false, - string reduction = ReductionV2.AUTO, - string name = "sparse_categorical_crossentropy") : - base(reduction: reduction, - name: name) - { - - } + public SparseCategoricalCrossentropy( + bool from_logits = false, + string reduction = null, + string name = null) : + base(reduction: reduction, name: name == null ? 
"sparse_categorical_crossentropy" : name){ } public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) { diff --git a/test/TensorFlowNET.UnitTest/Keras/CosineSimilarity.Test.cs b/test/TensorFlowNET.UnitTest/Keras/CosineSimilarity.Test.cs new file mode 100644 index 00000000..70e07264 --- /dev/null +++ b/test/TensorFlowNET.UnitTest/Keras/CosineSimilarity.Test.cs @@ -0,0 +1,76 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; +using Tensorflow; +using Tensorflow.Keras.Losses; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace TensorFlowNET.UnitTest.Keras +{ + [TestClass] + public class CosineSimilarity + { + //https://keras.io/api/losses/regression_losses/ + + NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 1.0f, 1.0f } }; + NDArray y_pred_float = new float[,] { { 1.0f, 0.0f }, { 1.0f, 1.0f } }; + + [TestMethod] + + public void _Default() + { + //>>> # Using 'auto'/'sum_over_batch_size' reduction type. + //>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis = 1) + //>>> # l2_norm(y_true) = [[0., 1.], [1./1.414], 1./1.414]]] + //>>> # l2_norm(y_pred) = [[1., 0.], [1./1.414], 1./1.414]]] + //>>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]] + //>>> # loss = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1)) + //>>> # = -((0. + 0.) + (0.5 + 0.5)) / 2 + //-0.5 + var loss = keras.losses.CosineSimilarity(axis : 1); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)(-0.49999997f), call.numpy()); + } + + [TestMethod] + + public void _Sample_Weight() + { + //>>> # Calling with 'sample_weight'. + //>>> cosine_loss(y_true, y_pred, sample_weight =[0.8, 0.2]).numpy() + //- 0.0999 + var loss = keras.losses.CosineSimilarity(); + var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f }); + Assert.AreEqual((NDArray) (- 0.099999994f), call.numpy()); + } + + [TestMethod] + + public void _SUM() + { + //>>> # Using 'sum' reduction type. + //>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis = 1, + //... reduction = tf.keras.losses.Reduction.SUM) + //>>> cosine_loss(y_true, y_pred).numpy() + //- 0.999 + var loss = keras.losses.CosineSimilarity(axis: 1,reduction : ReductionV2.SUM); + var call = loss.Call(y_true_float, y_pred_float); + Assert.AreEqual((NDArray)(-0.99999994f), call.numpy()); + } + + [TestMethod] + + public void _None() + { + //>>> # Using 'none' reduction type. + //>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis = 1, + //... 
diff --git a/test/TensorFlowNET.UnitTest/Keras/Huber.Test.cs b/test/TensorFlowNET.UnitTest/Keras/Huber.Test.cs
new file mode 100644
index 00000000..cbc16eab
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/Huber.Test.cs
@@ -0,0 +1,72 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class Huber
+    {
+        // https://keras.io/api/losses/regression_losses/#meansquarederror-class
+        NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+        NDArray y_pred_float = new float[,] { { 0.6f, 0.4f }, { 0.4f, 0.6f } };
+
+        [TestMethod]
+        public void _Default()
+        {
+            //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+            //>>> h = tf.keras.losses.Huber()
+            //>>> h(y_true, y_pred).numpy()
+            //0.155
+            var loss = keras.losses.Huber();
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)0.155f, call.numpy());
+        }
+
+        [TestMethod]
+        public void _Sample_Weight()
+        {
+            //>>> # Calling with 'sample_weight'.
+            //>>> h(y_true, y_pred, sample_weight =[0.1, 0.0]).numpy()
+            //0.009
+            var loss = keras.losses.Huber();
+            var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.1f, 0.0f });
+            Assert.AreEqual((NDArray)0.009000001f, call.numpy());
+        }
+
+        [TestMethod]
+        public void _SUM()
+        {
+            //>>> # Using 'sum' reduction type.
+            //>>> h = tf.keras.losses.Huber(
+            //...     reduction = tf.keras.losses.Reduction.SUM)
+            //>>> h(y_true, y_pred).numpy()
+            //0.31
+            var loss = keras.losses.Huber(reduction: ReductionV2.SUM);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)0.31f, call.numpy());
+        }
+
+        [TestMethod]
+        public void _None()
+        {
+            //>>> # Using 'none' reduction type.
+            //>>> h = tf.keras.losses.Huber(
+            //...     reduction = tf.keras.losses.Reduction.NONE)
+            //>>> h(y_true, y_pred).numpy()
+            //array([0.18, 0.13], dtype = float32)
+            var loss = keras.losses.Huber(reduction: ReductionV2.NONE);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)new float[] { 0.18f, 0.13000001f }, call.numpy());
+        }
+    }
+}
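Why `_Sample_Weight` expects 0.009: the per-sample Huber losses are `[0.18, 0.13]` (see `_None`), and under the default `SUM_OVER_BATCH_SIZE` reduction `compute_weighted_loss` divides the weighted sum by the batch size:

```latex
\frac{0.18 \cdot 0.1 + 0.13 \cdot 0.0}{2} = \frac{0.018}{2} = 0.009
```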
diff --git a/test/TensorFlowNET.UnitTest/Keras/LogCosh.Test.cs b/test/TensorFlowNET.UnitTest/Keras/LogCosh.Test.cs
new file mode 100644
index 00000000..48d2d862
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/LogCosh.Test.cs
@@ -0,0 +1,72 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class LogCosh
+    {
+        // https://keras.io/api/losses/regression_losses/#meansquarederror-class
+        NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+        NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 0.0f, 0.0f } };
+
+        [TestMethod]
+        public void _Default()
+        {
+            //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+            //>>> l = tf.keras.losses.LogCosh()
+            //>>> l(y_true, y_pred).numpy()
+            //0.108
+            var loss = keras.losses.LogCosh();
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)0.1084452f, call.numpy());
+        }
+
+        [TestMethod]
+        public void _Sample_Weight()
+        {
+            //>>> # Calling with 'sample_weight'.
+            //>>> l(y_true, y_pred, sample_weight =[0.8, 0.2]).numpy()
+            //0.087
+            var loss = keras.losses.LogCosh();
+            var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f });
+            Assert.AreEqual((NDArray)0.08675616f, call.numpy());
+        }
+
+        [TestMethod]
+        public void _SUM()
+        {
+            //>>> # Using 'sum' reduction type.
+            //>>> l = tf.keras.losses.LogCosh(
+            //...     reduction = tf.keras.losses.Reduction.SUM)
+            //>>> l(y_true, y_pred).numpy()
+            //0.217
+            var loss = keras.losses.LogCosh(reduction: ReductionV2.SUM);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)0.2168904f, call.numpy());
+        }
+
+        [TestMethod]
+        public void _None()
+        {
+            //>>> # Using 'none' reduction type.
+            //>>> l = tf.keras.losses.LogCosh(
+            //...     reduction = tf.keras.losses.Reduction.NONE)
+            //>>> l(y_true, y_pred).numpy()
+            //array([0.217, 0.], dtype = float32)
+            var loss = keras.losses.LogCosh(reduction: ReductionV2.NONE);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)new float[] { 0.2168904f, 0.0f }, call.numpy());
+        }
+    }
+}
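Plugging this test's data into the identity noted under `LogCosh.cs`: `x = y_pred - y_true = [[1, 0], [0, 0]]`, and

```latex
\log\cosh(1) = 1 + \operatorname{softplus}(-2) - \log 2 \approx 0.4337809, \qquad \log\cosh(0) = 0,
```

so the per-sample means are `[0.2168904, 0]`, matching `_None`, `_SUM` (their sum), and `_Default` (their mean, 0.1084452).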
diff --git a/test/TensorFlowNET.UnitTest/Keras/MeanAbsoluteError.Test.cs b/test/TensorFlowNET.UnitTest/Keras/MeanAbsoluteError.Test.cs
new file mode 100644
index 00000000..2b7a2504
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/MeanAbsoluteError.Test.cs
@@ -0,0 +1,73 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class MeanAbsoluteError
+    {
+        // https://keras.io/api/losses/regression_losses/
+        NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+        NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
+
+        [TestMethod]
+        public void _Default()
+        {
+            //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+            //>>> mae = tf.keras.losses.MeanAbsoluteError()
+            //>>> mae(y_true, y_pred).numpy()
+            //0.5
+            var loss = keras.losses.MeanAbsoluteError();
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(0.5f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _Sample_Weight()
+        {
+            //>>> # Calling with 'sample_weight'.
+            //>>> mae(y_true, y_pred, sample_weight =[0.7, 0.3]).numpy()
+            //0.25
+            var loss = keras.losses.MeanAbsoluteError();
+            var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.7f, 0.3f });
+            Assert.AreEqual((NDArray)(0.25f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _SUM()
+        {
+            //>>> # Using 'sum' reduction type.
+            //>>> mae = tf.keras.losses.MeanAbsoluteError(
+            //...     reduction = tf.keras.losses.Reduction.SUM)
+            //>>> mae(y_true, y_pred).numpy()
+            //1.0
+            var loss = keras.losses.MeanAbsoluteError(reduction: ReductionV2.SUM);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(1.0f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _None()
+        {
+            //>>> # Using 'none' reduction type.
+            //>>> mae = tf.keras.losses.MeanAbsoluteError(
+            //...     reduction = tf.keras.losses.Reduction.NONE)
+            //>>> mae(y_true, y_pred).numpy()
+            //array([0.5, 0.5], dtype = float32)
+            var loss = keras.losses.MeanAbsoluteError(reduction: ReductionV2.NONE);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)new float[] { 0.5f, 0.5f }, call.numpy());
+        }
+    }
+}
diff --git a/test/TensorFlowNET.UnitTest/Keras/MeanAbsolutePercentageError.Test.cs b/test/TensorFlowNET.UnitTest/Keras/MeanAbsolutePercentageError.Test.cs
new file mode 100644
index 00000000..97b43503
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/MeanAbsolutePercentageError.Test.cs
@@ -0,0 +1,72 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class MeanAbsolutePercentageError
+    {
+        // https://keras.io/api/losses/regression_losses/
+        NDArray y_true_float = new float[,] { { 2.0f, 1.0f }, { 2.0f, 3.0f } };
+        NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
+
+        [TestMethod]
+        public void _Default()
+        {
+            //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+            //>>> mape = tf.keras.losses.MeanAbsolutePercentageError()
+            //>>> mape(y_true, y_pred).numpy()
+            //50.
+            var loss = keras.losses.MeanAbsolutePercentageError();
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(50f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _Sample_Weight()
+        {
+            //>>> # Calling with 'sample_weight'.
+            //>>> mape(y_true, y_pred, sample_weight =[0.7, 0.3]).numpy()
+            //20.
+            var loss = keras.losses.MeanAbsolutePercentageError();
+            var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.7f, 0.3f });
+            Assert.AreEqual((NDArray)(20f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _SUM()
+        {
+            //>>> # Using 'sum' reduction type.
+            //>>> mape = tf.keras.losses.MeanAbsolutePercentageError(
+            //...     reduction = tf.keras.losses.Reduction.SUM)
+            //>>> mape(y_true, y_pred).numpy()
+            //100.
+            var loss = keras.losses.MeanAbsolutePercentageError(reduction: ReductionV2.SUM);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(100f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _None()
+        {
+            //>>> # Using 'none' reduction type.
+            //>>> mape = tf.keras.losses.MeanAbsolutePercentageError(
+            //...     reduction = tf.keras.losses.Reduction.NONE)
+            //>>> mape(y_true, y_pred).numpy()
+            //array([25., 75.], dtype = float32)
+            var loss = keras.losses.MeanAbsolutePercentageError(reduction: ReductionV2.NONE);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)new float[] { 25f, 75f }, call.numpy());
+        }
+    }
+}
diff --git a/test/TensorFlowNET.UnitTest/Keras/MeanSquaredError.Test.cs b/test/TensorFlowNET.UnitTest/Keras/MeanSquaredError.Test.cs
new file mode 100644
index 00000000..f1c782f8
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/MeanSquaredError.Test.cs
@@ -0,0 +1,65 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class MeanSquaredErrorTest
+    {
+        // https://keras.io/api/losses/regression_losses/#meansquarederror-class
+        private NDArray y_true = new double[,] { { 0.0, 1.0 }, { 0.0, 0.0 } };
+        private NDArray y_pred = new double[,] { { 1.0, 1.0 }, { 1.0, 0.0 } };
+
+        [TestMethod]
+        public void Mse_Double()
+        {
+            var mse = keras.losses.MeanSquaredError();
+            var call = mse.Call(y_true, y_pred);
+            Assert.AreEqual((NDArray)0.5, call.numpy());
+        }
+
+        [TestMethod]
+        public void Mse_Float()
+        {
+            NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+            NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
+
+            var mse = keras.losses.MeanSquaredError();
+            var call = mse.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)0.5, call.numpy());
+        }
+
+        [TestMethod]
+        public void Mse_Sample_Weight()
+        {
+            var mse = keras.losses.MeanSquaredError();
+            var call = mse.Call(y_true, y_pred, sample_weight: (NDArray)new double[] { 0.7, 0.3 });
+            Assert.AreEqual((NDArray)0.25, call.numpy());
+        }
+
+        [TestMethod]
+        public void Mse_Reduction_SUM()
+        {
+            var mse = keras.losses.MeanSquaredError(reduction: ReductionV2.SUM);
+            var call = mse.Call(y_true, y_pred);
+            Assert.AreEqual((NDArray)1.0, call.numpy());
+        }
+
+        [TestMethod]
+        public void Mse_Reduction_NONE()
+        {
+            var mse = keras.losses.MeanSquaredError(reduction: ReductionV2.NONE);
+            var call = mse.Call(y_true, y_pred);
+            Assert.AreEqual((NDArray)new double[] { 0.5, 0.5 }, call.numpy());
+        }
+    }
+}
diff --git a/test/TensorFlowNET.UnitTest/Keras/MeanSquaredLogarithmicError.Test.cs b/test/TensorFlowNET.UnitTest/Keras/MeanSquaredLogarithmicError.Test.cs
new file mode 100644
index 00000000..28499143
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/Keras/MeanSquaredLogarithmicError.Test.cs
@@ -0,0 +1,72 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class MeanSquaredLogarithmicError
+    {
+        // https://keras.io/api/losses/regression_losses/
+        NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+        NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
+
+        [TestMethod]
+        public void _Default()
+        {
+            //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+            //>>> msle = tf.keras.losses.MeanSquaredLogarithmicError()
+            //>>> msle(y_true, y_pred).numpy()
+            //0.240
+            var loss = keras.losses.MeanSquaredLogarithmicError();
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(0.24022643f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _Sample_Weight()
+        {
+            //>>> # Calling with 'sample_weight'.
+            //>>> msle(y_true, y_pred, sample_weight =[0.7, 0.3]).numpy()
+            //0.120
+            var loss = keras.losses.MeanSquaredLogarithmicError();
+            var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.7f, 0.3f });
+            Assert.AreEqual((NDArray)(0.12011322f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _SUM()
+        {
+            //>>> # Using 'sum' reduction type.
+            //>>> msle = tf.keras.losses.MeanSquaredLogarithmicError(
+            //...     reduction = tf.keras.losses.Reduction.SUM)
+            //>>> msle(y_true, y_pred).numpy()
+            //0.480
+            var loss = keras.losses.MeanSquaredLogarithmicError(reduction: ReductionV2.SUM);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(0.48045287f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _None()
+        {
+            //>>> # Using 'none' reduction type.
+            //>>> msle = tf.keras.losses.MeanSquaredLogarithmicError(
+            //...     reduction = tf.keras.losses.Reduction.NONE)
+            //>>> msle(y_true, y_pred).numpy()
+            //array([0.240, 0.240], dtype = float32)
+            var loss = keras.losses.MeanSquaredLogarithmicError(reduction: ReductionV2.NONE);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)new float[] { 0.24022643f, 0.24022643f }, call.numpy());
+        }
+    }
+}
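Taken together, the new losses are reachable through `keras.losses`. A minimal usage sketch under the conventions of the tests above (values illustrative):

```csharp
var y_true = np.array(new float[,] { { 0f, 1f }, { 0f, 0f } });
var y_pred = np.array(new float[,] { { 1f, 1f }, { 1f, 0f } });

// Scalar loss with the default sum_over_batch_size reduction.
ILossFunc mae = keras.losses.MeanAbsoluteError();
var scalar = mae.Call(y_true, y_pred);                 // 0.5

// Per-sample losses, and the optional sample_weight added to ILossFunc.Call.
var perSample = keras.losses.MeanAbsoluteError(reduction: ReductionV2.NONE)
                     .Call(y_true, y_pred);            // [0.5, 0.5]
var weighted = mae.Call(y_true, y_pred, sample_weight: (NDArray)new float[] { 0.7f, 0.3f });  // 0.25
```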