You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

MetricsApi.cs 2.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. using static Tensorflow.KerasApi;
  2. namespace Tensorflow.Keras.Metrics
  3. {
  /// <summary>
  /// Keras-style metric functions: binary, categorical and sparse categorical
  /// accuracy, plus mean absolute error.
  /// </summary>
  public class MetricsApi
  {
    /// <summary>
    /// Calculates how often predictions match binary labels, using a threshold of 0.5.
    /// </summary>
    /// <remarks>
    /// Kept as a separate two-argument overload (rather than giving the threshold an
    /// optional parameter) so that existing callers passing this method group as a
    /// two-argument delegate such as <c>Func&lt;Tensor, Tensor, Tensor&gt;</c> keep compiling.
    /// </remarks>
    /// <param name="y_true">Ground truth values.</param>
    /// <param name="y_pred">The predicted values.</param>
    /// <returns>The mean of the element-wise equality over the last axis.</returns>
    public Tensor binary_accuracy(Tensor y_true, Tensor y_pred)
    {
      return binary_accuracy(y_true, y_pred, 0.5f);
    }

    /// <summary>
    /// Calculates how often predictions match binary labels, using a caller-supplied threshold.
    /// </summary>
    /// <param name="y_true">Ground truth values.</param>
    /// <param name="y_pred">The predicted values.</param>
    /// <param name="threshold">Predictions strictly greater than this value count as 1;
    /// all others count as 0.</param>
    /// <returns>The mean of the element-wise equality over the last axis.</returns>
    public Tensor binary_accuracy(Tensor y_true, Tensor y_pred, float threshold)
    {
      // Binarize predictions; the comparison result is cast back to y_pred's own dtype.
      // NOTE(review): y_true is not cast here, so it presumably must already share
      // y_pred's dtype for math_ops.equal — confirm against callers.
      var binarized = math_ops.cast(y_pred > threshold, y_pred.dtype);
      return keras.backend.mean(math_ops.equal(y_true, binarized), axis: -1);
    }

    /// <summary>
    /// Calculates, per sample, whether the arg-max class of the predictions equals the
    /// arg-max class of the (one-hot style) targets.
    /// </summary>
    /// <param name="y_true">Ground truth values; the class is taken as the argmax over the last axis.</param>
    /// <param name="y_pred">The prediction values; the class is taken as the argmax over the last axis.</param>
    /// <returns>A float tensor that is 1 where the two argmax results agree and 0 elsewhere
    /// (not averaged).</returns>
    public Tensor categorical_accuracy(Tensor y_true, Tensor y_pred)
    {
      var eql = math_ops.equal(math_ops.argmax(y_true, -1), math_ops.argmax(y_pred, -1));
      return math_ops.cast(eql, TF_DataType.TF_FLOAT);
    }

    /// <summary>
    /// Calculates how often predictions match integer labels.
    /// </summary>
    /// <param name="y_true">Integer ground truth values.</param>
    /// <param name="y_pred">The prediction values; the predicted class is the argmax over the last axis.</param>
    /// <returns>A float tensor that is 1 where the predicted class equals the label and 0 elsewhere
    /// (not averaged).</returns>
    public Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred)
    {
      var y_pred_rank = y_pred.TensorShape.ndim;
      var y_true_rank = y_true.TensorShape.ndim;
      // If the shape of y_true is (num_samples, 1), squeeze to (num_samples,).
      // This runs only when both ranks are known (-1 is treated as unknown) and y_true has as
      // many dimensions as y_pred, presumably meaning it still has a trailing label axis of size 1.
      if (y_true_rank != -1 && y_pred_rank != -1
          && y_true.shape.Length == y_pred.shape.Length)
      {
        y_true = array_ops.squeeze(y_true, axis: new[] { -1 });
      }

      // Reduce the per-class prediction scores to the predicted class index.
      y_pred = math_ops.argmax(y_pred, -1);

      // If the predicted output and actual output types don't match, force cast them
      // to match.
      if (y_pred.dtype != y_true.dtype)
      {
        y_pred = math_ops.cast(y_pred, y_true.dtype);
      }

      return math_ops.cast(math_ops.equal(y_true, y_pred), TF_DataType.TF_FLOAT);
    }

    /// <summary>
    /// Computes the mean absolute error between targets and predictions.
    /// </summary>
    /// <param name="y_true">Ground truth values; cast to <paramref name="y_pred"/>'s dtype first.</param>
    /// <param name="y_pred">The predicted values.</param>
    /// <returns>The mean of |y_pred - y_true| over the last axis.</returns>
    public Tensor mean_absolute_error(Tensor y_true, Tensor y_pred)
    {
      y_true = math_ops.cast(y_true, y_pred.dtype);
      return keras.backend.mean(math_ops.abs(y_pred - y_true), axis: -1);
    }
  }
  44. }