|
|
|
@@ -1,4 +1,4 @@ |
|
|
|
using System; |
|
|
|
using System; |
|
|
|
using Tensorflow.Keras.Utils; |
|
|
|
|
|
|
|
namespace Tensorflow.Keras.Losses |
|
|
|
@@ -31,10 +31,11 @@ namespace Tensorflow.Keras.Losses |
|
|
|
throw new NotImplementedException(""); |
|
|
|
} |
|
|
|
|
|
|
|
public Tensor Call(Tensor y_true, Tensor y_pred) |
|
|
|
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) |
|
|
|
{ |
|
|
|
var losses = Apply(y_true, y_pred, from_logits: from_logits); |
|
|
|
return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE); |
|
|
|
|
|
|
|
return losses_utils.compute_weighted_loss(losses, reduction: this.reduction == null?ReductionV2.SUM_OVER_BATCH_SIZE : this.reduction, sample_weight: sample_weight); |
|
|
|
} |
|
|
|
|
|
|
|
void _set_name_scope() |
|
|
|
|