From 6c4d9a8fda8b069941eee8d2c621cc66231fcae1 Mon Sep 17 00:00:00 2001 From: dataangel Date: Tue, 15 Dec 2020 00:05:25 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0Mse?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/TensorFlowNET.Keras/Losses/ILossFunc.cs | 2 +- src/TensorFlowNET.Keras/Losses/Loss.cs | 5 ++-- src/TensorFlowNET.Keras/Losses/LossesApi.cs | 3 +++ .../Losses/MeanSquaredError.cs | 26 +++++++++++++++++++ 4 files changed, 33 insertions(+), 3 deletions(-) create mode 100644 src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs diff --git a/src/TensorFlowNET.Keras/Losses/ILossFunc.cs b/src/TensorFlowNET.Keras/Losses/ILossFunc.cs index 45c39dd2..59730bd1 100644 --- a/src/TensorFlowNET.Keras/Losses/ILossFunc.cs +++ b/src/TensorFlowNET.Keras/Losses/ILossFunc.cs @@ -3,6 +3,6 @@ public interface ILossFunc { string Reduction { get; } - Tensor Call(Tensor y_true, Tensor y_pred); + Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null); } } diff --git a/src/TensorFlowNET.Keras/Losses/Loss.cs b/src/TensorFlowNET.Keras/Losses/Loss.cs index 54b2b249..857ef505 100644 --- a/src/TensorFlowNET.Keras/Losses/Loss.cs +++ b/src/TensorFlowNET.Keras/Losses/Loss.cs @@ -31,10 +31,11 @@ namespace Tensorflow.Keras.Losses throw new NotImplementedException(""); } - public Tensor Call(Tensor y_true, Tensor y_pred) + public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) { var losses = Apply(y_true, y_pred, from_logits: from_logits); - return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE); + + return losses_utils.compute_weighted_loss(losses, reduction: this.reduction == null?ReductionV2.SUM_OVER_BATCH_SIZE : this.reduction, sample_weight: sample_weight); } void _set_name_scope() diff --git a/src/TensorFlowNET.Keras/Losses/LossesApi.cs b/src/TensorFlowNET.Keras/Losses/LossesApi.cs index 3e66b395..7067666a 100644 --- a/src/TensorFlowNET.Keras/Losses/LossesApi.cs +++ b/src/TensorFlowNET.Keras/Losses/LossesApi.cs @@ -7,5 +7,8 @@ public ILossFunc CategoricalCrossentropy(bool from_logits = false) => new CategoricalCrossentropy(from_logits: from_logits); + + public ILossFunc MeanSquaredError(string reduction = null) + => new MeanSquaredError(reduction: reduction); } } diff --git a/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs new file mode 100644 index 00000000..1123c01e --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Losses +{ + public class MeanSquaredError : LossFunctionWrapper, ILossFunc + { + public MeanSquaredError( + string reduction = ReductionV2.AUTO, + string name = "mean_squared_error") : + base(reduction: reduction, + name: name) + { + } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1); + } + } +}