| @@ -1,4 +1,4 @@ | |||||
| using System; | |||||
| using System; | |||||
| using Tensorflow.Keras.Utils; | using Tensorflow.Keras.Utils; | ||||
| namespace Tensorflow.Keras.Losses | namespace Tensorflow.Keras.Losses | ||||
| @@ -31,10 +31,11 @@ namespace Tensorflow.Keras.Losses | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| public Tensor Call(Tensor y_true, Tensor y_pred) | |||||
| public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) | |||||
| { | { | ||||
| var losses = Apply(y_true, y_pred, from_logits: from_logits); | var losses = Apply(y_true, y_pred, from_logits: from_logits); | ||||
| return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE); | |||||
| return losses_utils.compute_weighted_loss(losses, reduction: this.reduction == null?ReductionV2.SUM_OVER_BATCH_SIZE : this.reduction, sample_weight: sample_weight); | |||||
| } | } | ||||
| void _set_name_scope() | void _set_name_scope() | ||||