@@ -506,6 +506,27 @@ namespace Tensorflow
             }
         }
 
+        public static Tensor where_v2(Tensor condition, object x = null, object y = null, string name = null)
+        {
+            if (x == null && y == null)
+            {
+                return tf_with(ops.name_scope(name, "Where", new { condition }), scope =>
+                {
+                    name = scope;
+                    condition = ops.convert_to_tensor(condition, preferred_dtype: dtypes.@bool, name: "condition");
+                    return gen_array_ops.where(condition: condition, name: name);
+                });
+            }
+            else if (x != null && y != null)
+            {
+                return gen_array_ops.select_v2(condition, x, y, name);
+            }
+            else
+            {
+                throw new ValueError("x and y must both be non-None or both be None.");
+            }
+        }
+
         /// <summary>
         /// Returns the shape of a tensor.
         /// </summary>
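Note: where_v2 mirrors tf.where's dual behavior. A minimal usage sketch (the
result values follow TF's documented tf.where semantics, not anything asserted
in this PR):

    // One-argument form: coordinates of true elements, shape (num_true, rank).
    var cond = tf.constant(new[] { true, false, true });
    var coords = array_ops.where_v2(cond);                  // [[0], [2]]

    // Three-argument form: element-wise select between x and y.
    var picked = array_ops.where_v2(cond,
        tf.constant(new[] { 1, 2, 3 }),                     // taken where cond is true
        tf.constant(new[] { -1, -2, -3 }));                 // taken where cond is false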
@@ -423,6 +423,21 @@ namespace Tensorflow
             var _op = tf.OpDefLib._apply_op_helper("Select", name, new { condition, t = x, e = y });
             return _op.outputs[0];
         }
 
+        public static Tensor select_v2<Tx, Ty>(Tensor condition, Tx x, Ty y, string name = null)
+        {
+            if (tf.Context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "SelectV2", name,
+                    null,
+                    condition, x, y);
+                return results[0];
+            }
+
+            var _op = tf.OpDefLib._apply_op_helper("SelectV2", name, new { condition, t = x, e = y });
+            return _op.outputs[0];
+        }
+
         public static Tensor scatter_nd(Tensor indices, Tensor updates, Tensor[] shape, string name = null)
         {
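Note: this is the standard two-branch dispatch used throughout the generated ops
(and repeated below for Softplus, Abs and Rsqrt): in eager mode the op executes
immediately through TFE_FastPathExecute, otherwise an op node is added to the
graph. The practical difference from legacy Select is that SelectV2 broadcasts
its arguments. A sketch (values follow TF's documented SelectV2 broadcasting
rules):

    var cond = tf.constant(true);                    // scalar condition broadcasts
    var x = tf.constant(new[] { 1, 2 });
    var y = tf.constant(new[] { 3, 4 });
    var z = gen_array_ops.select_v2(cond, x, y);     // [1, 2]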
@@ -714,7 +714,23 @@ namespace Tensorflow
             return _op.outputs[0];
         }
 
+        public static Tensor softplus(Tensor features, string name = null)
+        {
+            if (tf.Context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "Softplus", name,
+                    null,
+                    features);
+                return results[0];
+            }
+
+            var _op = tf.OpDefLib._apply_op_helper("Softplus", name, args: new { features });
+            return _op.outputs[0];
+        }
+
         public static Tensor cast(Tensor x, TF_DataType DstT, bool Truncate = false, string name = null)
             => tf.Context.RunInAutoMode(()
                 => tf.OpDefLib._apply_op_helper("Cast", name, args: new { x, DstT, Truncate }).output, ()
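Note: softplus(x) = ln(1 + e^x), the smooth approximation of ReLU; it is exposed
here because the new LogCosh loss below builds on it. A quick numeric check
(plain math, not a value asserted anywhere in this PR):

    var s = gen_math_ops.softplus(tf.constant(0.0f));   // ln(2) ≈ 0.6931472f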
@@ -1068,6 +1084,15 @@ namespace Tensorflow
         public static Tensor _abs(Tensor x, string name = null)
         {
+            if (tf.Context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "Abs", name,
+                    null,
+                    x);
+                return results[0];
+            }
+
             var _op = tf.OpDefLib._apply_op_helper("Abs", name, args: new { x });
             return _op.output;
@@ -1202,6 +1227,15 @@ namespace Tensorflow
         /// <returns></returns>
         public static Tensor rsqrt(Tensor x, string name = null)
         {
+            if (tf.Context.executing_eagerly())
+            {
+                var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
+                    "Rsqrt", name,
+                    null,
+                    x);
+                return results[0];
+            }
+
             var _op = tf.OpDefLib._apply_op_helper("Rsqrt", name, new { x });
             return _op.outputs[0];
@@ -31,7 +31,7 @@ namespace Tensorflow
         /// <returns></returns>
         public static Tensor l2_normalize(Tensor x,
             int axis = 0,
-            float epsilon = 1e-12f,
+            Tensor epsilon = null,
             string name = null)
         {
             return tf_with(ops.name_scope(name, "l2_normalize", new { x }), scope =>
@@ -39,7 +39,7 @@ namespace Tensorflow
                 x = ops.convert_to_tensor(x, name: "x");
                 var sq = math_ops.square(x);
                 var square_sum = math_ops.reduce_sum(sq, axis, keepdims: true);
-                var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon));
+                var x_inv_norm = math_ops.rsqrt(math_ops.maximum(square_sum, epsilon == null ? tf.constant(1e-12f) : epsilon));
                 return math_ops.multiply(x, x_inv_norm, name: name);
             });
         }
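Note: switching epsilon from float to Tensor is a breaking change for existing
call sites, while the null default reproduces the old 1e-12 floor. Callers now
wrap literals themselves; a sketch, assuming x is any float Tensor:

    var normalized = nn_impl.l2_normalize(x, axis: 1, epsilon: tf.constant(1e-6f));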
@@ -9,18 +9,19 @@ namespace Tensorflow.Keras.Losses
     public class CategoricalCrossentropy : LossFunctionWrapper, ILossFunc
     {
         float label_smoothing;
 
-        public CategoricalCrossentropy(bool from_logits = false,
+        public CategoricalCrossentropy(
+            bool from_logits = false,
             float label_smoothing = 0,
-            string reduction = ReductionV2.AUTO,
-            string name = "categorical_crossentropy") :
-            base(reduction: reduction,
-                name: name,
-                from_logits: from_logits)
+            string reduction = null,
+            string name = null) :
+            base(reduction: reduction,
+                name: name == null ? "categorical_crossentropy" : name,
+                from_logits: from_logits)
         {
             this.label_smoothing = label_smoothing;
         }
 
         public override Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
         {
             // Try to adjust the shape so that rank of labels = rank of logits - 1.
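Note: the constructor change is behavior-preserving for default callers; null
reduction and name resolve to the old defaults in the base call. A construction
sketch:

    var cce = new CategoricalCrossentropy(from_logits: true);   // name -> "categorical_crossentropy"
    var named = new CategoricalCrossentropy(reduction: ReductionV2.NONE, name: "my_cce");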
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+    public class CosineSimilarity : LossFunctionWrapper, ILossFunc
+    {
+        protected int axis = -1;
+
+        public CosineSimilarity(
+            string reduction = null,
+            int axis = -1,
+            string name = null) :
+            base(reduction: reduction, name: name == null ? "cosine_similarity" : name)
+        {
+            this.axis = axis;
+        }
+
+        public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+        {
+            Tensor y_true_normalize = nn_impl.l2_normalize(y_true, axis: this.axis);
+            Tensor y_pred_normalize = nn_impl.l2_normalize(y_pred, axis: this.axis);
+            return -math_ops.reduce_sum(y_true_normalize * y_pred_normalize, axis: this.axis);
+        }
+    }
+}
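Note: per sample this computes loss = -sum(l2_norm(y_true) · l2_norm(y_pred), axis),
so -1 means the vectors point the same way, 0 means orthogonal, 1 means opposite.
Worked check against the CosineSimilarity test below (axis = 1):

    row 0: l2_norm([0, 1]) · l2_norm([1, 0]) = 0   -> loss -0
    row 1: l2_norm([1, 1]) · l2_norm([1, 1]) = 1   -> loss -1 (≈ -0.99999994 in float32)
    batch mean under the default reduction -> -0.5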
@@ -0,0 +1,36 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+    public class Huber : LossFunctionWrapper, ILossFunc
+    {
+        protected Tensor delta = tf.constant(1.0);
+
+        public Huber(
+            string reduction = null,
+            Tensor delta = null,
+            string name = null) :
+            base(reduction: reduction, name: name == null ? "huber" : name)
+        {
+            this.delta = delta == null ? this.delta : delta;
+        }
+
+        public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+        {
+            Tensor y_pred_cast = math_ops.cast(y_pred, dtype: TF_DataType.TF_FLOAT);
+            Tensor y_true_cast = math_ops.cast(y_true, dtype: TF_DataType.TF_FLOAT);
+            Tensor delta = math_ops.cast(this.delta, dtype: TF_DataType.TF_FLOAT);
+            Tensor error = math_ops.subtract(y_pred_cast, y_true_cast);
+            Tensor abs_error = math_ops.abs(error);
+            Tensor half = ops.convert_to_tensor(0.5, dtype: abs_error.dtype);
+            return gen_math_ops.mean(array_ops.where_v2(abs_error <= delta,
+                    half * math_ops.pow(error, 2),
+                    half * math_ops.pow(delta, 2) + delta * (abs_error - delta)),
+                axis: -1);
+        }
+    }
+}
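Note: the where_v2 select encodes the standard Huber definition for error
e = y_pred - y_true and threshold delta:

    loss(e) = 0.5 * e^2                               if |e| <= delta
    loss(e) = 0.5 * delta^2 + delta * (|e| - delta)   otherwise

i.e. quadratic near zero and linear in the tails, then averaged over the last axis.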
@@ -2,7 +2,8 @@
 {
     public interface ILossFunc
     {
-        string Reduction { get; }
+        public string Reduction { get; }
+        public string Name { get; }
         Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null);
     }
 }
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Operations;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+    public class LogCosh : LossFunctionWrapper, ILossFunc
+    {
+        public LogCosh(
+            string reduction = null,
+            string name = null) :
+            base(reduction: reduction, name: name == null ? "log_cosh" : name) { }
+
+        public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+        {
+            Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
+            Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
+            Tensor x = y_pred_dispatch - y_true_cast;
+            return gen_math_ops.mean(
+                x + gen_math_ops.softplus(-2.0 * x) - math_ops.cast(math_ops.log(tf.constant(2.0)), x.dtype),
+                axis: -1);
+        }
+    }
+}
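Note: the Apply body relies on the numerically stable identity Keras uses for
logcosh. Starting from cosh(x) = (e^x + e^-x) / 2 = e^x * (1 + e^-2x) / 2:

    log(cosh(x)) = x + log(1 + e^-2x) - log(2) = x + softplus(-2x) - log(2)

which avoids overflowing e^x for large |x|, unlike a naive log(cosh(x)).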
@@ -15,12 +15,12 @@ namespace Tensorflow.Keras.Losses
         string _name_scope;
 
         public string Reduction => reduction;
+        public string Name => name;
 
         public Loss(string reduction = ReductionV2.AUTO,
             string name = null,
             bool from_logits = false)
         {
-            this.reduction = reduction;
+            this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction;
             this.name = name;
             this.from_logits = from_logits;
             _allow_sum_over_batch_size = false;
@@ -34,8 +34,7 @@ namespace Tensorflow.Keras.Losses
         public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
         {
             var losses = Apply(y_true, y_pred, from_logits: from_logits);
-            return losses_utils.compute_weighted_loss(losses, reduction: this.reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : this.reduction, sample_weight: sample_weight);
+            return losses_utils.compute_weighted_loss(losses, reduction: this.reduction, sample_weight: sample_weight);
         }
 
         void _set_name_scope()
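Note: moving the null check into the constructor makes the resolved default
visible through the Reduction/Name surface instead of being applied silently on
every Call. A sketch:

    ILossFunc mse = new MeanSquaredError();      // reduction: null at the call site
    // mse.Reduction == ReductionV2.SUM_OVER_BATCH_SIZE  ("sum_over_batch_size")
    // mse.Name == "mean_squared_error"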
@@ -2,13 +2,31 @@
 {
     public class LossesApi
     {
-        public ILossFunc SparseCategoricalCrossentropy(bool from_logits = false)
-            => new SparseCategoricalCrossentropy(from_logits: from_logits);
+        public ILossFunc SparseCategoricalCrossentropy(string reduction = null, string name = null, bool from_logits = false)
+            => new SparseCategoricalCrossentropy(reduction: reduction, name: name, from_logits: from_logits);
 
-        public ILossFunc CategoricalCrossentropy(bool from_logits = false)
-            => new CategoricalCrossentropy(from_logits: from_logits);
+        public ILossFunc CategoricalCrossentropy(string reduction = null, string name = null, bool from_logits = false)
+            => new CategoricalCrossentropy(reduction: reduction, name: name, from_logits: from_logits);
 
-        public ILossFunc MeanSquaredError(string reduction = null)
-            => new MeanSquaredError(reduction: reduction);
+        public ILossFunc MeanSquaredError(string reduction = null, string name = null)
+            => new MeanSquaredError(reduction: reduction, name: name);
+
+        public ILossFunc MeanSquaredLogarithmicError(string reduction = null, string name = null)
+            => new MeanSquaredLogarithmicError(reduction: reduction, name: name);
+
+        public ILossFunc MeanAbsolutePercentageError(string reduction = null, string name = null)
+            => new MeanAbsolutePercentageError(reduction: reduction, name: name);
+
+        public ILossFunc MeanAbsoluteError(string reduction = null, string name = null)
+            => new MeanAbsoluteError(reduction: reduction, name: name);
+
+        public ILossFunc CosineSimilarity(string reduction = null, string name = null, int axis = -1)
+            => new CosineSimilarity(reduction: reduction, name: name, axis: axis);
+
+        public ILossFunc Huber(string reduction = null, string name = null, Tensor delta = null)
+            => new Huber(reduction: reduction, name: name, delta: delta);
+
+        public ILossFunc LogCosh(string reduction = null, string name = null)
+            => new LogCosh(reduction: reduction, name: name);
     }
 }
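Note: every factory now exposes reduction and name uniformly, so the keras
facade can drive all of the new losses the same way (mirroring the unit tests
below):

    var huber = keras.losses.Huber(reduction: ReductionV2.SUM);
    var mape = keras.losses.MeanAbsolutePercentageError(name: "my_mape");
    var cosine = keras.losses.CosineSimilarity(axis: 1);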
@@ -0,0 +1,23 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+    public class MeanAbsoluteError : LossFunctionWrapper, ILossFunc
+    {
+        public MeanAbsoluteError(
+            string reduction = null,
+            string name = null) :
+            base(reduction: reduction, name: name == null ? "mean_absolute_error" : name) { }
+
+        public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+        {
+            Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
+            Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
+            return gen_math_ops.mean(math_ops.abs(y_pred_dispatch - y_true_cast), axis: -1);
+        }
+    }
+}
@@ -0,0 +1,24 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+    public class MeanAbsolutePercentageError : LossFunctionWrapper, ILossFunc
+    {
+        public MeanAbsolutePercentageError(
+            string reduction = null,
+            string name = null) :
+            base(reduction: reduction, name: name == null ? "mean_absolute_percentage_error" : name) { }
+
+        public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+        {
+            Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
+            Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
+            Tensor diff = math_ops.abs(y_true_cast - y_pred_dispatch) / gen_math_ops.maximum(math_ops.abs(y_true_cast), gen_math_ops.cast(tf.constant(1e-7), y_pred_dispatch.dtype));
+            return gen_math_ops.cast(tf.constant(100), y_pred_dispatch.dtype) * gen_math_ops.mean(diff, axis: -1);
+        }
+    }
+}
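Note: the maximum(|y_true|, 1e-7) guard keeps the division defined when a true
value is zero, and the factor 100 turns the ratio into a percentage. Worked
check against the MeanAbsolutePercentageError test below (y_true = [[2, 1], [2, 3]],
y_pred = [[1, 1], [1, 0]]):

    |y_true - y_pred| / max(|y_true|, 1e-7) = [[0.5, 0.0], [0.5, 1.0]]
    row means = [0.25, 0.75] -> x100 -> [25, 75] -> batch mean 50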
@@ -9,12 +9,9 @@ namespace Tensorflow.Keras.Losses
     public class MeanSquaredError : LossFunctionWrapper, ILossFunc
     {
         public MeanSquaredError(
-            string reduction = ReductionV2.AUTO,
-            string name = "mean_squared_error") :
-            base(reduction: reduction,
-                name: name)
-        {
-        }
+            string reduction = null,
+            string name = null) :
+            base(reduction: reduction, name: name == null ? "mean_squared_error" : name) { }
 
         public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
         {
@@ -0,0 +1,33 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace Tensorflow.Keras.Losses
+{
+    public class MeanSquaredLogarithmicError : LossFunctionWrapper, ILossFunc
+    {
+        public MeanSquaredLogarithmicError(
+            string reduction = null,
+            string name = null) :
+            base(reduction: reduction, name: name == null ? "mean_squared_logarithmic_error" : name) { }
+
+        public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
+        {
+            Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
+            Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
+            Tensor first_log = null, second_log = null;
+            if (y_pred_dispatch.dtype == TF_DataType.TF_DOUBLE)
+            {
+                first_log = math_ops.log(gen_math_ops.maximum(y_pred_dispatch, 1e-7) + 1.0);
+                second_log = math_ops.log(gen_math_ops.maximum(y_true_cast, 1e-7) + 1.0);
+            }
+            else
+            {
+                first_log = math_ops.log(gen_math_ops.maximum(y_pred_dispatch, 1e-7f) + 1.0f);
+                second_log = math_ops.log(gen_math_ops.maximum(y_true_cast, 1e-7f) + 1.0f);
+            }
+            return gen_math_ops.mean(gen_math_ops.squared_difference(first_log, second_log), axis: -1);
+        }
+    }
+}
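Note: this computes mean((log(y_pred + 1) - log(y_true + 1))^2), with values
clipped at 1e-7 so the log stays finite at zero; the dtype branch only exists so
the literals match the tensor's float/double dtype. Worked check against the
test below: wherever y_pred and y_true differ by (1, 0), the squared log
difference is (ln 2)^2 ≈ 0.4805, and each row averages to ≈ 0.2402, matching the
expected 0.24022643.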
@@ -4,6 +4,7 @@
 {
     public const string NONE = "none";
     public const string AUTO = "auto";
+    public const string SUM = "sum";
     public const string SUM_OVER_BATCH_SIZE = "sum_over_batch_size";
     public const string WEIGHTED_MEAN = "weighted_mean";
 }
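Note: SUM adds the per-sample losses up instead of averaging them over the
batch; in the MeanAbsoluteError tests below the same data yields 0.5 under the
default reduction and 1.0 under SUM. A sketch:

    var mae = keras.losses.MeanAbsoluteError(reduction: ReductionV2.SUM);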
@@ -4,14 +4,11 @@ namespace Tensorflow.Keras.Losses
 {
     public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc
     {
-        public SparseCategoricalCrossentropy(bool from_logits = false,
-            string reduction = ReductionV2.AUTO,
-            string name = "sparse_categorical_crossentropy") :
-            base(reduction: reduction,
-                name: name)
-        {
-        }
+        public SparseCategoricalCrossentropy(
+            bool from_logits = false,
+            string reduction = null,
+            string name = null) :
+            base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name) { }
 
         public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
         {
@@ -0,0 +1,76 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class CosineSimilarity
+    {
+        // https://keras.io/api/losses/regression_losses/
+        NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 1.0f, 1.0f } };
+        NDArray y_pred_float = new float[,] { { 1.0f, 0.0f }, { 1.0f, 1.0f } };
+
+        [TestMethod]
+        public void _Default()
+        {
+            //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+            //>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis = 1)
+            //>>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]
+            //>>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]
+            //>>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]
+            //>>> # loss = -mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))
+            //>>> #      = -((0. + 0.) + (0.5 + 0.5)) / 2
+            //-0.5
+            var loss = keras.losses.CosineSimilarity(axis: 1);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(-0.49999997f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _Sample_Weight()
+        {
+            //>>> # Calling with 'sample_weight'.
+            //>>> cosine_loss(y_true, y_pred, sample_weight =[0.8, 0.2]).numpy()
+            //-0.0999
+            var loss = keras.losses.CosineSimilarity();
+            var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f });
+            Assert.AreEqual((NDArray)(-0.099999994f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _SUM()
+        {
+            //>>> # Using 'sum' reduction type.
+            //>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis = 1,
+            //...     reduction = tf.keras.losses.Reduction.SUM)
+            //>>> cosine_loss(y_true, y_pred).numpy()
+            //-0.999
+            var loss = keras.losses.CosineSimilarity(axis: 1, reduction: ReductionV2.SUM);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(-0.99999994f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _None()
+        {
+            //>>> # Using 'none' reduction type.
+            //>>> cosine_loss = tf.keras.losses.CosineSimilarity(axis = 1,
+            //...     reduction = tf.keras.losses.Reduction.NONE)
+            //>>> cosine_loss(y_true, y_pred).numpy()
+            //array([-0., -0.999], dtype = float32)
+            var loss = keras.losses.CosineSimilarity(axis: 1, reduction: ReductionV2.NONE);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)new float[] { -0f, -0.99999994f }, call.numpy());
+        }
+    }
+}
@@ -0,0 +1,72 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class Huber
+    {
+        // https://keras.io/api/losses/regression_losses/#huber-class
+        NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+        NDArray y_pred_float = new float[,] { { 0.6f, 0.4f }, { 0.4f, 0.6f } };
+
+        [TestMethod]
+        public void _Default()
+        {
+            //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+            //>>> h = tf.keras.losses.Huber()
+            //>>> h(y_true, y_pred).numpy()
+            //0.155
+            var loss = keras.losses.Huber();
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)0.155f, call.numpy());
+        }
+
+        [TestMethod]
+        public void _Sample_Weight()
+        {
+            //>>> # Calling with 'sample_weight'.
+            //>>> h(y_true, y_pred, sample_weight =[1, 0]).numpy()
+            //0.09
+            var loss = keras.losses.Huber();
+            var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.1f, 0.0f });
+            Assert.AreEqual((NDArray)0.009000001f, call.numpy());
+        }
+
+        [TestMethod]
+        public void _SUM()
+        {
+            //>>> # Using 'sum' reduction type.
+            //>>> h = tf.keras.losses.Huber(
+            //...     reduction = tf.keras.losses.Reduction.SUM)
+            //>>> h(y_true, y_pred).numpy()
+            //0.31
+            var loss = keras.losses.Huber(reduction: ReductionV2.SUM);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)0.31f, call.numpy());
+        }
+
+        [TestMethod]
+        public void _None()
+        {
+            //>>> # Using 'none' reduction type.
+            //>>> h = tf.keras.losses.Huber(
+            //...     reduction = tf.keras.losses.Reduction.NONE)
+            //>>> h(y_true, y_pred).numpy()
+            //array([0.18, 0.13], dtype = float32)
+            var loss = keras.losses.Huber(reduction: ReductionV2.NONE);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)new float[] { 0.18f, 0.13000001f }, call.numpy());
+        }
+    }
+}
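Hand computation behind the values asserted above (delta = 1, so every error
stays in the quadratic branch):

    e = y_pred - y_true = [[0.6, -0.6], [0.4, 0.6]]
    0.5 * e^2 = [[0.18, 0.18], [0.08, 0.18]]
    row means -> [0.18, 0.13] (NONE); batch mean -> 0.155 (default); sum -> 0.31 (SUM)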
@@ -0,0 +1,72 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class LogCosh
+    {
+        // https://keras.io/api/losses/regression_losses/#logcosh-class
+        NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+        NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 0.0f, 0.0f } };
+
+        [TestMethod]
+        public void _Default()
+        {
+            //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+            //>>> l = tf.keras.losses.LogCosh()
+            //>>> l(y_true, y_pred).numpy()
+            //0.108
+            var loss = keras.losses.LogCosh();
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)0.1084452f, call.numpy());
+        }
+
+        [TestMethod]
+        public void _Sample_Weight()
+        {
+            //>>> # Calling with 'sample_weight'.
+            //>>> l(y_true, y_pred, sample_weight =[0.8, 0.2]).numpy()
+            //0.087
+            var loss = keras.losses.LogCosh();
+            var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.8f, 0.2f });
+            Assert.AreEqual((NDArray)0.08675616f, call.numpy());
+        }
+
+        [TestMethod]
+        public void _SUM()
+        {
+            //>>> # Using 'sum' reduction type.
+            //>>> l = tf.keras.losses.LogCosh(
+            //...     reduction = tf.keras.losses.Reduction.SUM)
+            //>>> l(y_true, y_pred).numpy()
+            //0.217
+            var loss = keras.losses.LogCosh(reduction: ReductionV2.SUM);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)0.2168904f, call.numpy());
+        }
+
+        [TestMethod]
+        public void _None()
+        {
+            //>>> # Using 'none' reduction type.
+            //>>> l = tf.keras.losses.LogCosh(
+            //...     reduction = tf.keras.losses.Reduction.NONE)
+            //>>> l(y_true, y_pred).numpy()
+            //array([0.217, 0.], dtype = float32)
+            var loss = keras.losses.LogCosh(reduction: ReductionV2.NONE);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)new float[] { 0.2168904f, 0.0f }, call.numpy());
+        }
+    }
+}
@@ -0,0 +1,73 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class MeanAbsoluteError
+    {
+        // https://keras.io/api/losses/regression_losses/
+        NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+        NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
+
+        [TestMethod]
+        public void _Default()
+        {
+            //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+            //>>> mae = tf.keras.losses.MeanAbsoluteError()
+            //>>> mae(y_true, y_pred).numpy()
+            //0.5
+            var loss = keras.losses.MeanAbsoluteError();
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(0.5f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _Sample_Weight()
+        {
+            //>>> # Calling with 'sample_weight'.
+            //>>> mae(y_true, y_pred, sample_weight =[0.7, 0.3]).numpy()
+            //0.25
+            var loss = keras.losses.MeanAbsoluteError();
+            var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.7f, 0.3f });
+            Assert.AreEqual((NDArray)(0.25f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _SUM()
+        {
+            //>>> # Using 'sum' reduction type.
+            //>>> mae = tf.keras.losses.MeanAbsoluteError(
+            //...     reduction = tf.keras.losses.Reduction.SUM)
+            //>>> mae(y_true, y_pred).numpy()
+            //1.0
+            var loss = keras.losses.MeanAbsoluteError(reduction: ReductionV2.SUM);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(1.0f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _None()
+        {
+            //>>> # Using 'none' reduction type.
+            //>>> mae = tf.keras.losses.MeanAbsoluteError(
+            //...     reduction = tf.keras.losses.Reduction.NONE)
+            //>>> mae(y_true, y_pred).numpy()
+            //array([0.5, 0.5], dtype = float32)
+            var loss = keras.losses.MeanAbsoluteError(reduction: ReductionV2.NONE);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)new float[] { 0.5f, 0.5f }, call.numpy());
+        }
+    }
+}
@@ -0,0 +1,72 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class MeanAbsolutePercentageError
+    {
+        // https://keras.io/api/losses/regression_losses/
+        NDArray y_true_float = new float[,] { { 2.0f, 1.0f }, { 2.0f, 3.0f } };
+        NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
+
+        [TestMethod]
+        public void _Default()
+        {
+            //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+            //>>> mape = tf.keras.losses.MeanAbsolutePercentageError()
+            //>>> mape(y_true, y_pred).numpy()
+            //50.
+            var loss = keras.losses.MeanAbsolutePercentageError();
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(50f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _Sample_Weight()
+        {
+            //>>> # Calling with 'sample_weight'.
+            //>>> mape(y_true, y_pred, sample_weight =[0.7, 0.3]).numpy()
+            //20.
+            var loss = keras.losses.MeanAbsolutePercentageError();
+            var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.7f, 0.3f });
+            Assert.AreEqual((NDArray)(20f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _SUM()
+        {
+            //>>> # Using 'sum' reduction type.
+            //>>> mape = tf.keras.losses.MeanAbsolutePercentageError(
+            //...     reduction = tf.keras.losses.Reduction.SUM)
+            //>>> mape(y_true, y_pred).numpy()
+            //100.
+            var loss = keras.losses.MeanAbsolutePercentageError(reduction: ReductionV2.SUM);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(100f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _None()
+        {
+            //>>> # Using 'none' reduction type.
+            //>>> mape = tf.keras.losses.MeanAbsolutePercentageError(
+            //...     reduction = tf.keras.losses.Reduction.NONE)
+            //>>> mape(y_true, y_pred).numpy()
+            //array([25., 75.], dtype = float32)
+            var loss = keras.losses.MeanAbsolutePercentageError(reduction: ReductionV2.NONE);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)new float[] { 25f, 75f }, call.numpy());
+        }
+    }
+}
@@ -0,0 +1,72 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using NumSharp;
+using Tensorflow;
+using Tensorflow.Keras.Losses;
+using static Tensorflow.Binding;
+using static Tensorflow.KerasApi;
+
+namespace TensorFlowNET.UnitTest.Keras
+{
+    [TestClass]
+    public class MeanSquaredLogarithmicError
+    {
+        // https://keras.io/api/losses/regression_losses/
+        NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } };
+        NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } };
+
+        [TestMethod]
+        public void _Default()
+        {
+            //>>> # Using 'auto'/'sum_over_batch_size' reduction type.
+            //>>> msle = tf.keras.losses.MeanSquaredLogarithmicError()
+            //>>> msle(y_true, y_pred).numpy()
+            //0.240
+            var loss = keras.losses.MeanSquaredLogarithmicError();
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(0.24022643f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _Sample_Weight()
+        {
+            //>>> # Calling with 'sample_weight'.
+            //>>> msle(y_true, y_pred, sample_weight =[0.7, 0.3]).numpy()
+            //0.120
+            var loss = keras.losses.MeanSquaredLogarithmicError();
+            var call = loss.Call(y_true_float, y_pred_float, sample_weight: (NDArray)new float[] { 0.7f, 0.3f });
+            Assert.AreEqual((NDArray)(0.12011322f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _SUM()
+        {
+            //>>> # Using 'sum' reduction type.
+            //>>> msle = tf.keras.losses.MeanSquaredLogarithmicError(
+            //...     reduction = tf.keras.losses.Reduction.SUM)
+            //>>> msle(y_true, y_pred).numpy()
+            //0.480
+            var loss = keras.losses.MeanSquaredLogarithmicError(reduction: ReductionV2.SUM);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)(0.48045287f), call.numpy());
+        }
+
+        [TestMethod]
+        public void _None()
+        {
+            //>>> # Using 'none' reduction type.
+            //>>> msle = tf.keras.losses.MeanSquaredLogarithmicError(
+            //...     reduction = tf.keras.losses.Reduction.NONE)
+            //>>> msle(y_true, y_pred).numpy()
+            //array([0.240, 0.240], dtype = float32)
+            var loss = keras.losses.MeanSquaredLogarithmicError(reduction: ReductionV2.NONE);
+            var call = loss.Call(y_true_float, y_pred_float);
+            Assert.AreEqual((NDArray)new float[] { 0.24022643f, 0.24022643f }, call.numpy());
+        }
+    }
+}