Browse Source

Update Loss.cs

pull/682/head
dataangel GitHub 5 years ago
parent
commit
d203323040
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 3 deletions
  1. +4
    -3
      src/TensorFlowNET.Keras/Losses/Loss.cs

+ 4
- 3
src/TensorFlowNET.Keras/Losses/Loss.cs View File

@@ -1,4 +1,4 @@
using System;
using System;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Losses
@@ -31,10 +31,11 @@ namespace Tensorflow.Keras.Losses
throw new NotImplementedException("");
}

public Tensor Call(Tensor y_true, Tensor y_pred)
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{
var losses = Apply(y_true, y_pred, from_logits: from_logits);
return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE);

return losses_utils.compute_weighted_loss(losses, reduction: this.reduction == null?ReductionV2.SUM_OVER_BATCH_SIZE : this.reduction, sample_weight: sample_weight);
}

void _set_name_scope()


Loading…
Cancel
Save