You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Precision.cs 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. namespace Tensorflow.Keras.Metrics;
  2. public class Precision : Metric
  3. {
  4. Tensor _thresholds;
  5. int _top_k;
  6. int _class_id;
  7. IVariableV1 true_positives;
  8. IVariableV1 false_positives;
  9. bool _thresholds_distributed_evenly;
  10. public Precision(float thresholds = 0.5f, int top_k = 0, int class_id = 0, string name = "recall", TF_DataType dtype = TF_DataType.TF_FLOAT)
  11. : base(name: name, dtype: dtype)
  12. {
  13. _thresholds = constant_op.constant(new float[] { thresholds });
  14. _top_k = top_k;
  15. _class_id = class_id;
  16. true_positives = add_weight("true_positives", shape: 1, initializer: tf.initializers.zeros_initializer());
  17. false_positives = add_weight("false_positives", shape: 1, initializer: tf.initializers.zeros_initializer());
  18. }
  19. public override Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
  20. {
  21. return metrics_utils.update_confusion_matrix_variables(
  22. new Dictionary<string, IVariableV1>
  23. {
  24. { "tp", true_positives },
  25. { "fp", false_positives },
  26. },
  27. y_true,
  28. y_pred,
  29. thresholds: _thresholds,
  30. thresholds_distributed_evenly: _thresholds_distributed_evenly,
  31. top_k: _top_k,
  32. class_id: _class_id,
  33. sample_weight: sample_weight);
  34. }
  35. public override Tensor result()
  36. {
  37. var result = tf.divide(true_positives.AsTensor(), tf.add(true_positives, false_positives));
  38. return _thresholds.size == 1 ? result[0] : result;
  39. }
  40. public override void reset_states()
  41. {
  42. var num_thresholds = (int)_thresholds.size;
  43. keras.backend.batch_set_value(
  44. new List<(IVariableV1, NDArray)>
  45. {
  46. (true_positives, np.zeros(num_thresholds)),
  47. (false_positives, np.zeros(num_thresholds))
  48. });
  49. }
  50. }