You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

MeanSquaredError.cs 880 B

5 years ago
1234567891011121314151617181920212223242526
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using static Tensorflow.Binding;
  5. using static Tensorflow.KerasApi;
  6. namespace Tensorflow.Keras.Losses
  7. {
  8. public class MeanSquaredError : LossFunctionWrapper, ILossFunc
  9. {
  10. public MeanSquaredError(
  11. string reduction = ReductionV2.AUTO,
  12. string name = "mean_squared_error") :
  13. base(reduction: reduction,
  14. name: name)
  15. {
  16. }
  17. public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
  18. {
  19. Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
  20. Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
  21. return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1);
  22. }
  23. }
  24. }