using System;
using Tensorflow.Keras.Utils;

namespace Tensorflow.Keras.Losses
{
    /// <summary>
    /// Loss base class.
    /// </summary>
    /// <remarks>
    /// Derived classes override <see cref="Apply"/> to compute the loss tensor.
    /// <see cref="Call"/> invokes <see cref="Apply"/> and passes its result, together with the
    /// resolved reduction and an optional sample weight, to
    /// <c>losses_utils.compute_weighted_loss</c>.
    /// </remarks>
    public abstract class Loss
    {
        protected string reduction;
        protected string name;

        // NOTE(review): set to false in the constructor but never read in this class;
        // presumably kept for parity with the Python Keras implementation — confirm before removing.
        bool _allow_sum_over_batch_size;

        protected bool from_logits = false;

        // NOTE(review): only written by _set_name_scope, which nothing in this class calls.
        string _name_scope;

        /// <summary>
        /// The configured reduction. When <c>null</c> is passed to the constructor this is
        /// <c>ReductionV2.SUM_OVER_BATCH_SIZE</c>; <c>ReductionV2.AUTO</c> is stored as-is and
        /// only resolved when <see cref="Call"/> runs.
        /// </summary>
        public string Reduction => reduction;

        /// <summary>The optional name given to this loss (may be <c>null</c>).</summary>
        public string Name => name;

        /// <summary>
        /// Initializes the shared loss configuration.
        /// </summary>
        /// <param name="reduction">
        /// Reduction to apply in <see cref="Call"/>. Defaults to <c>ReductionV2.AUTO</c>;
        /// <c>null</c> is replaced by <c>ReductionV2.SUM_OVER_BATCH_SIZE</c>.
        /// </param>
        /// <param name="name">Optional name for the loss.</param>
        /// <param name="from_logits">
        /// Value forwarded to <see cref="Apply"/> by <see cref="Call"/>.
        /// </param>
        public Loss(string reduction = ReductionV2.AUTO, string name = null, bool from_logits = false)
        {
            this.reduction = reduction ?? ReductionV2.SUM_OVER_BATCH_SIZE;
            this.name = name;
            this.from_logits = from_logits;
            _allow_sum_over_batch_size = false;
        }

        /// <summary>
        /// Computes the loss tensor for <paramref name="y_true"/> and <paramref name="y_pred"/>.
        /// Must be overridden by derived classes.
        /// </summary>
        /// <param name="y_true">Ground-truth values.</param>
        /// <param name="y_pred">Predicted values.</param>
        /// <param name="from_logits">Interpretation flag for <paramref name="y_pred"/>, as defined by the derived loss.</param>
        /// <param name="axis">Axis argument, as defined by the derived loss.</param>
        /// <returns>The loss tensor, which <see cref="Call"/> weights and reduces.</returns>
        /// <exception cref="NotImplementedException">Always, unless overridden.</exception>
        public virtual Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
        {
            throw new NotImplementedException(
                $"{GetType().Name} does not override {nameof(Apply)}; derived losses must implement it.");
        }

        /// <summary>
        /// Computes the loss via <see cref="Apply"/> and applies sample weighting and reduction.
        /// </summary>
        /// <param name="y_true">Ground-truth values.</param>
        /// <param name="y_pred">Predicted values.</param>
        /// <param name="sample_weight">Optional weights forwarded to <c>compute_weighted_loss</c>.</param>
        /// <returns>The weighted, reduced loss produced by <c>losses_utils.compute_weighted_loss</c>.</returns>
        public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
        {
            // Forward the instance-level from_logits setting rather than Apply's parameter default (false).
            var losses = Apply(y_true, y_pred, from_logits: from_logits);

            // Named distinctly from the `reduction` field: this is the resolved value (AUTO mapped).
            var effectiveReduction = GetReduction();

            return losses_utils.compute_weighted_loss(losses, reduction: effectiveReduction, sample_weight: sample_weight);
        }

        /// <summary>
        /// Resolves <c>ReductionV2.AUTO</c> to <c>ReductionV2.SUM_OVER_BATCH_SIZE</c>;
        /// any other value is returned unchanged.
        /// </summary>
        string GetReduction() => reduction switch
        {
            ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE,
            _ => reduction
        };

        // NOTE(review): not called anywhere in this class.
        void _set_name_scope()
        {
            _name_scope = name;
        }
    }
}