You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Loss.cs 1.6 kB

5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. using System;
  2. using Tensorflow.Keras.Utils;
  3. namespace Tensorflow.Keras.Losses
  4. {
  5. /// <summary>
  6. /// Loss base class.
  7. /// </summary>
  8. public abstract class Loss
  9. {
  10. protected string reduction;
  11. protected string name;
  12. bool _allow_sum_over_batch_size;
  13. protected bool from_logits = false;
  14. string _name_scope;
  15. public string Reduction => reduction;
  16. public string Name => name;
  17. public Loss(string reduction = ReductionV2.AUTO,
  18. string name = null,
  19. bool from_logits = false)
  20. {
  21. this.reduction = reduction == null ? ReductionV2.SUM_OVER_BATCH_SIZE : reduction;
  22. this.name = name;
  23. this.from_logits = from_logits;
  24. _allow_sum_over_batch_size = false;
  25. }
  26. public virtual Tensor Apply(Tensor y_true, Tensor y_pred, bool from_logits = false, int axis = -1)
  27. {
  28. throw new NotImplementedException("");
  29. }
  30. public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
  31. {
  32. var losses = Apply(y_true, y_pred, from_logits: from_logits);
  33. var reduction = GetReduction();
  34. return losses_utils.compute_weighted_loss(losses, reduction: reduction, sample_weight: sample_weight);
  35. }
  36. string GetReduction()
  37. {
  38. return reduction switch
  39. {
  40. ReductionV2.AUTO => ReductionV2.SUM_OVER_BATCH_SIZE,
  41. _ => reduction
  42. };
  43. }
  44. void _set_name_scope()
  45. {
  46. _name_scope = name;
  47. }
  48. }
  49. }