diff --git a/src/TensorFlowNET.Keras/Losses/Loss.cs b/src/TensorFlowNET.Keras/Losses/Loss.cs index 54b2b249..d25d11f4 100644 --- a/src/TensorFlowNET.Keras/Losses/Loss.cs +++ b/src/TensorFlowNET.Keras/Losses/Loss.cs @@ -1,4 +1,4 @@ -using System; +using System; using Tensorflow.Keras.Utils; namespace Tensorflow.Keras.Losses @@ -31,10 +31,11 @@ namespace Tensorflow.Keras.Losses throw new NotImplementedException(""); } - public Tensor Call(Tensor y_true, Tensor y_pred) + public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) { var losses = Apply(y_true, y_pred, from_logits: from_logits); - return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE); + + return losses_utils.compute_weighted_loss(losses, reduction: this.reduction == null?ReductionV2.SUM_OVER_BATCH_SIZE : this.reduction, sample_weight: sample_weight); } void _set_name_scope()