diff --git a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs index b814a259..64c2c14f 100644 --- a/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs +++ b/src/TensorFlowNET.Core/Keras/Metrics/IMetricsApi.cs @@ -77,7 +77,7 @@ public interface IMetricsApi /// IMetricFunc F1Score(int num_classes, string? average = null, - float threshold = -1f, + float? threshold = null, string name = "f1_score", TF_DataType dtype = TF_DataType.TF_FLOAT); @@ -88,7 +88,7 @@ public interface IMetricsApi IMetricFunc FBetaScore(int num_classes, string? average = null, float beta = 0.1f, - float threshold = -1f, + float? threshold = null, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT); diff --git a/src/TensorFlowNET.Keras/Metrics/F1Score.cs b/src/TensorFlowNET.Keras/Metrics/F1Score.cs index c3276f3e..fc24136d 100644 --- a/src/TensorFlowNET.Keras/Metrics/F1Score.cs +++ b/src/TensorFlowNET.Keras/Metrics/F1Score.cs @@ -4,7 +4,7 @@ public class F1Score : FBetaScore { public F1Score(int num_classes, string? average = null, - float? threshold = -1f, + float? threshold = null, string name = "f1_score", TF_DataType dtype = TF_DataType.TF_FLOAT) : base(num_classes, average: average, threshold: threshold, beta: 1f, name: name, dtype: dtype) diff --git a/src/TensorFlowNET.Keras/Metrics/FBetaScore.cs b/src/TensorFlowNET.Keras/Metrics/FBetaScore.cs index ab4d00a9..39e3e9af 100644 --- a/src/TensorFlowNET.Keras/Metrics/FBetaScore.cs +++ b/src/TensorFlowNET.Keras/Metrics/FBetaScore.cs @@ -17,7 +17,7 @@ public class FBetaScore : Metric public FBetaScore(int num_classes, string? average = null, float beta = 0.1f, - float? threshold = -1f, + float? threshold = null, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT) : base(name: name, dtype: dtype) diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs index b9fbe180..bd12f82a 100644 --- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs +++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs @@ -86,10 +86,10 @@ public IMetricFunc CosineSimilarity(string name = "cosine_similarity", TF_DataType dtype = TF_DataType.TF_FLOAT, Axis? axis = null) => new CosineSimilarity(name: name, dtype: dtype, axis: axis ?? -1); - public IMetricFunc F1Score(int num_classes, string? average = null, float threshold = -1, string name = "f1_score", TF_DataType dtype = TF_DataType.TF_FLOAT) + public IMetricFunc F1Score(int num_classes, string? average = null, float? threshold = null, string name = "f1_score", TF_DataType dtype = TF_DataType.TF_FLOAT) => new F1Score(num_classes, average: average, threshold: threshold, name: name, dtype: dtype); - public IMetricFunc FBetaScore(int num_classes, string? average = null, float beta = 0.1F, float threshold = -1, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT) + public IMetricFunc FBetaScore(int num_classes, string? average = null, float beta = 0.1F, float? threshold = null, string name = "fbeta_score", TF_DataType dtype = TF_DataType.TF_FLOAT) => new FBetaScore(num_classes, average: average,beta: beta, threshold: threshold, name: name, dtype: dtype); public IMetricFunc TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", TF_DataType dtype = TF_DataType.TF_FLOAT)