diff --git a/src/TensorFlowNET.Keras/Losses/Loss.cs b/src/TensorFlowNET.Keras/Losses/Loss.cs index 49246027..8acee5ba 100644 --- a/src/TensorFlowNET.Keras/Losses/Loss.cs +++ b/src/TensorFlowNET.Keras/Losses/Loss.cs @@ -6,5 +6,36 @@ namespace Tensorflow.Keras.Losses { public abstract class Loss { + public static Tensor mean_squared_error(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor mean_absolute_error(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor mean_absolute_percentage_error(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor mean_squared_logarithmic_error(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor _maybe_convert_labels(Tensor y_true) => throw new NotImplementedException(); + + public static Tensor squared_hinge(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor hinge(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor categorical_hinge(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor huber_loss(Tensor y_true, Tensor y_pred, float delta = 1) => throw new NotImplementedException(); + + public static Tensor logcosh(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float label_smoothing = 0) => throw new NotImplementedException(); + + public static Tensor sparse_categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float axis = -1) => throw new NotImplementedException(); + + public static Tensor binary_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float label_smoothing = 0) => throw new NotImplementedException(); + + public static Tensor kullback_leibler_divergence(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor poisson(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor cosine_similarity(Tensor y_true, Tensor y_pred, int axis = -1) => throw new NotImplementedException(); } } diff --git a/src/TensorFlowNET.Keras/Metrics/AUC.cs b/src/TensorFlowNET.Keras/Metrics/AUC.cs index e7e03626..c34f61c8 100644 --- a/src/TensorFlowNET.Keras/Metrics/AUC.cs +++ b/src/TensorFlowNET.Keras/Metrics/AUC.cs @@ -1,10 +1,41 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.Metrics { - class AUC + public class AUC : Metric { + public AUC(int num_thresholds= 200, string curve= "ROC", string summation_method= "interpolation", + string name= null, string dtype= null, float thresholds= 0.5f, + bool multi_label= false, Tensor label_weights= null) : base(name, dtype) + { + throw new NotImplementedException(); + } + + private void _build(TensorShape shape) => throw new NotImplementedException(); + + public Tensor interpolate_pr_auc() => throw new NotImplementedException(); + + public override Tensor result() + { + throw new NotImplementedException(); + } + + public override void update_state(Args args, KwArgs kwargs) + { + throw new NotImplementedException(); + } + + public override void reset_states() + { + throw new NotImplementedException(); + } + + public override Hashtable get_config() + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/Accuracy.cs b/src/TensorFlowNET.Keras/Metrics/Accuracy.cs index 66a774e5..cb58ae91 100644 --- a/src/TensorFlowNET.Keras/Metrics/Accuracy.cs +++ b/src/TensorFlowNET.Keras/Metrics/Accuracy.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class Accuracy + public class Accuracy : MeanMetricWrapper { + public Accuracy(string name = "accuracy", string dtype = null) + : base(Metric.accuracy, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs index b6e564ff..682ed236 100644 --- a/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs +++ b/src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs @@ -4,7 +4,16 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class BinaryAccuracy + public class BinaryAccuracy : MeanMetricWrapper { + public BinaryAccuracy(string name = "binary_accuracy", string dtype = null, float threshold = 0.5f) + : base(Fn, name, dtype) + { + } + + internal static Tensor Fn(Tensor y_true, Tensor y_pred) + { + return Metric.binary_accuracy(y_true, y_pred); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs index cdc62dd9..64b31f64 100644 --- a/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs +++ b/src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class CategoricalAccuracy + public class CategoricalAccuracy : MeanMetricWrapper { + public CategoricalAccuracy(string name = "categorical_accuracy", string dtype = null) + : base(Metric.categorical_accuracy, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/CategoricalHinge.cs b/src/TensorFlowNET.Keras/Metrics/CategoricalHinge.cs index 2f50ea76..1f82d725 100644 --- a/src/TensorFlowNET.Keras/Metrics/CategoricalHinge.cs +++ b/src/TensorFlowNET.Keras/Metrics/CategoricalHinge.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class CategoricalHinge + public class CategoricalHinge : MeanMetricWrapper { + public CategoricalHinge(string name = "categorical_hinge", string dtype = null) + : base(Losses.Loss.categorical_hinge, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/CosineSimilarity.cs b/src/TensorFlowNET.Keras/Metrics/CosineSimilarity.cs index c1dd618a..abce27c8 100644 --- a/src/TensorFlowNET.Keras/Metrics/CosineSimilarity.cs +++ b/src/TensorFlowNET.Keras/Metrics/CosineSimilarity.cs @@ -4,7 +4,16 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class CosineSimilarity + public class CosineSimilarity : MeanMetricWrapper { + public CosineSimilarity(string name = "cosine_similarity", string dtype = null, int axis = -1) + : base(Fn, name, dtype) + { + } + + internal static Tensor Fn(Tensor y_true, Tensor y_pred) + { + return Metric.cosine_proximity(y_true, y_pred); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/FalseNegatives.cs b/src/TensorFlowNET.Keras/Metrics/FalseNegatives.cs index 075a8373..fb27484e 100644 --- a/src/TensorFlowNET.Keras/Metrics/FalseNegatives.cs +++ b/src/TensorFlowNET.Keras/Metrics/FalseNegatives.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class FalseNegatives + public class FalseNegatives : _ConfusionMatrixConditionCount { + public FalseNegatives(float thresholds = 0.5F, string name = null, string dtype = null) + : base(Utils.MetricsUtils.ConfusionMatrix.FALSE_NEGATIVES, thresholds, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/FalsePositives.cs b/src/TensorFlowNET.Keras/Metrics/FalsePositives.cs index fc7ad152..1b97e556 100644 --- a/src/TensorFlowNET.Keras/Metrics/FalsePositives.cs +++ b/src/TensorFlowNET.Keras/Metrics/FalsePositives.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class FalsePositives + public class FalsePositives : _ConfusionMatrixConditionCount { + public FalsePositives(float thresholds = 0.5F, string name = null, string dtype = null) + : base(Utils.MetricsUtils.ConfusionMatrix.FALSE_POSITIVES, thresholds, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/Hinge.cs b/src/TensorFlowNET.Keras/Metrics/Hinge.cs index f8d7eef2..21ebe067 100644 --- a/src/TensorFlowNET.Keras/Metrics/Hinge.cs +++ b/src/TensorFlowNET.Keras/Metrics/Hinge.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class Hinge + public class Hinge : MeanMetricWrapper { + public Hinge(string name = "hinge", string dtype = null) + : base(Losses.Loss.hinge, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/LogCoshError.cs b/src/TensorFlowNET.Keras/Metrics/LogCoshError.cs index b2f8d040..595f4aa7 100644 --- a/src/TensorFlowNET.Keras/Metrics/LogCoshError.cs +++ b/src/TensorFlowNET.Keras/Metrics/LogCoshError.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class LogCoshError + public class LogCoshError : MeanMetricWrapper { + public LogCoshError(string name = "logcosh", string dtype = null) + : base(Losses.Loss.logcosh, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/Mean.cs b/src/TensorFlowNET.Keras/Metrics/Mean.cs index 6a61846d..64b8b5db 100644 --- a/src/TensorFlowNET.Keras/Metrics/Mean.cs +++ b/src/TensorFlowNET.Keras/Metrics/Mean.cs @@ -4,7 +4,12 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class Mean + public class Mean : Reduce { + public Mean(string name, string dtype = null) + : base(Reduction.MEAN, name, dtype) + { + } + } } diff --git a/src/TensorFlowNET.Keras/Metrics/MeanAbsoluteError.cs b/src/TensorFlowNET.Keras/Metrics/MeanAbsoluteError.cs index 04a44dcd..c326a6dd 100644 --- a/src/TensorFlowNET.Keras/Metrics/MeanAbsoluteError.cs +++ b/src/TensorFlowNET.Keras/Metrics/MeanAbsoluteError.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class MeanAbsoluteError + public class MeanAbsoluteError : MeanMetricWrapper { + public MeanAbsoluteError(string name = "mean_absolute_error", string dtype = null) + : base(Losses.Loss.mean_absolute_error, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/MeanAbsolutePercentageError.cs b/src/TensorFlowNET.Keras/Metrics/MeanAbsolutePercentageError.cs index 1d75096d..0c51a5be 100644 --- a/src/TensorFlowNET.Keras/Metrics/MeanAbsolutePercentageError.cs +++ b/src/TensorFlowNET.Keras/Metrics/MeanAbsolutePercentageError.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class MeanAbsolutePercentageError + public class MeanAbsolutePercentageError : MeanMetricWrapper { + public MeanAbsolutePercentageError(string name = "mean_absolute_percentage_error", string dtype = null) + : base(Losses.Loss.mean_absolute_percentage_error, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs index 7bfdfcdc..ccc7922b 100644 --- a/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs +++ b/src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs @@ -1,10 +1,25 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.Metrics { - class MeanMetricWrapper + public class MeanMetricWrapper : Mean { + public MeanMetricWrapper(Func fn, string name, string dtype = null) : base(name, dtype) + { + throw new NotImplementedException(); + } + + public override Tensor result() + { + throw new NotImplementedException(); + } + + public override Hashtable get_config() + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/MeanRelativeError.cs b/src/TensorFlowNET.Keras/Metrics/MeanRelativeError.cs index 8cbad89a..9ae76a6a 100644 --- a/src/TensorFlowNET.Keras/Metrics/MeanRelativeError.cs +++ b/src/TensorFlowNET.Keras/Metrics/MeanRelativeError.cs @@ -1,10 +1,30 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.Metrics { - class MeanRelativeError + public class MeanRelativeError : Metric { + public MeanRelativeError(Tensor normalizer, string name, string dtype) : base(name, dtype) + { + throw new NotImplementedException(); + } + + public override Tensor result() + { + throw new NotImplementedException(); + } + + public override void update_state(Args args, KwArgs kwargs) + { + throw new NotImplementedException(); + } + + public override Hashtable get_config() + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/MeanSquaredError.cs b/src/TensorFlowNET.Keras/Metrics/MeanSquaredError.cs index f0a8ed86..e23b0f41 100644 --- a/src/TensorFlowNET.Keras/Metrics/MeanSquaredError.cs +++ b/src/TensorFlowNET.Keras/Metrics/MeanSquaredError.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class MeanSquaredError + public class MeanSquaredError : MeanMetricWrapper { + public MeanSquaredError(string name = "mean_squared_error", string dtype = null) + : base(Losses.Loss.mean_squared_error, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/MeanSquaredLogarithmicError.cs b/src/TensorFlowNET.Keras/Metrics/MeanSquaredLogarithmicError.cs index 6139f216..9f56b9d8 100644 --- a/src/TensorFlowNET.Keras/Metrics/MeanSquaredLogarithmicError.cs +++ b/src/TensorFlowNET.Keras/Metrics/MeanSquaredLogarithmicError.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class MeanSquaredLogarithmicError + public class MeanSquaredLogarithmicError : MeanMetricWrapper { + public MeanSquaredLogarithmicError(string name = "mean_squared_logarithmic_error", string dtype = null) + : base(Losses.Loss.mean_squared_logarithmic_error, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/Metric.cs b/src/TensorFlowNET.Keras/Metrics/Metric.cs index 83af1fde..10a3676b 100644 --- a/src/TensorFlowNET.Keras/Metrics/Metric.cs +++ b/src/TensorFlowNET.Keras/Metrics/Metric.cs @@ -35,5 +35,29 @@ namespace Tensorflow.Keras.Metrics public void add_weight(string name, TensorShape shape= null, VariableAggregation aggregation= VariableAggregation.Sum, VariableSynchronization synchronization = VariableSynchronization.OnRead, Initializers.Initializer initializer= null, string dtype= null) => throw new NotImplementedException(); + + public static Tensor accuracy(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor binary_accuracy(Tensor y_true, Tensor y_pred, float threshold = 0.5f) => throw new NotImplementedException(); + + public static Tensor categorical_accuracy(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); + + public static Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5) => throw new NotImplementedException(); + + public static Tensor sparse_top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5) => throw new NotImplementedException(); + + public static Tensor cosine_proximity(Tensor y_true, Tensor y_pred, int axis = -1) => throw new NotImplementedException(); + + public static Metric clone_metric(Metric metric) => throw new NotImplementedException(); + + public static Metric[] clone_metrics(Metric[] metric) => throw new NotImplementedException(); + + public static string serialize(Metric metric) => throw new NotImplementedException(); + + public static Metric deserialize(string config, object custom_objects = null) => throw new NotImplementedException(); + + public static Metric get(object identifier) => throw new NotImplementedException(); } } diff --git a/src/TensorFlowNET.Keras/Metrics/Poisson.cs b/src/TensorFlowNET.Keras/Metrics/Poisson.cs index 80bf5c52..7cdf5bd9 100644 --- a/src/TensorFlowNET.Keras/Metrics/Poisson.cs +++ b/src/TensorFlowNET.Keras/Metrics/Poisson.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class Poisson + public class Poisson : MeanMetricWrapper { + public Poisson(string name = "logcosh", string dtype = null) + : base(Losses.Loss.logcosh, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/Precision.cs b/src/TensorFlowNET.Keras/Metrics/Precision.cs index 8c5838c5..3d5c7248 100644 --- a/src/TensorFlowNET.Keras/Metrics/Precision.cs +++ b/src/TensorFlowNET.Keras/Metrics/Precision.cs @@ -1,10 +1,41 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.Metrics { - class Precision + public class Precision : Metric { + public Precision(float? thresholds = null, int? top_k = null, int? class_id = null, string name = null, string dtype = null) : base(name, dtype) + { + throw new NotImplementedException(); + } + + public Precision(float[] thresholds = null, int? top_k = null, int? class_id = null, string name = null, string dtype = null) : base(name, dtype) + { + throw new NotImplementedException(); + } + + public override Tensor result() + { + throw new NotImplementedException(); + } + + public override void update_state(Args args, KwArgs kwargs) + { + throw new NotImplementedException(); + } + + public override void reset_states() + { + throw new NotImplementedException(); + } + + public override Hashtable get_config() + { + throw new NotImplementedException(); + } + } } diff --git a/src/TensorFlowNET.Keras/Metrics/PrecisionAtRecall.cs b/src/TensorFlowNET.Keras/Metrics/PrecisionAtRecall.cs index 2191c7c1..05558232 100644 --- a/src/TensorFlowNET.Keras/Metrics/PrecisionAtRecall.cs +++ b/src/TensorFlowNET.Keras/Metrics/PrecisionAtRecall.cs @@ -1,10 +1,25 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.Metrics { - class PrecisionAtRecall + public class PrecisionAtRecall : SensitivitySpecificityBase { + public PrecisionAtRecall(float recall, int num_thresholds = 200, string name = null, string dtype = null) : base(recall, num_thresholds, name, dtype) + { + throw new NotImplementedException(); + } + + public override Tensor result() + { + throw new NotImplementedException(); + } + + public override Hashtable get_config() + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/Recall.cs b/src/TensorFlowNET.Keras/Metrics/Recall.cs index ac65d250..804d4461 100644 --- a/src/TensorFlowNET.Keras/Metrics/Recall.cs +++ b/src/TensorFlowNET.Keras/Metrics/Recall.cs @@ -1,10 +1,41 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.Metrics { - class Recall + public class Recall : Metric { + public Recall(float? thresholds = null, int? top_k = null, int? class_id = null, string name = null, string dtype = null) : base(name, dtype) + { + throw new NotImplementedException(); + } + + public Recall(float[] thresholds = null, int? top_k = null, int? class_id = null, string name = null, string dtype = null) : base(name, dtype) + { + throw new NotImplementedException(); + } + + public override Tensor result() + { + throw new NotImplementedException(); + } + + public override void update_state(Args args, KwArgs kwargs) + { + throw new NotImplementedException(); + } + + public override void reset_states() + { + throw new NotImplementedException(); + } + + public override Hashtable get_config() + { + throw new NotImplementedException(); + } + } } diff --git a/src/TensorFlowNET.Keras/Metrics/Reduce.cs b/src/TensorFlowNET.Keras/Metrics/Reduce.cs index e383dc69..143f441e 100644 --- a/src/TensorFlowNET.Keras/Metrics/Reduce.cs +++ b/src/TensorFlowNET.Keras/Metrics/Reduce.cs @@ -4,7 +4,22 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class Reduce + public class Reduce : Metric { + public Reduce(string reduction, string name, string dtype= null) + : base(name, dtype) + { + throw new NotImplementedException(); + } + + public override Tensor result() + { + throw new NotImplementedException(); + } + + public override void update_state(Args args, KwArgs kwargs) + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/RootMeanSquaredError.cs b/src/TensorFlowNET.Keras/Metrics/RootMeanSquaredError.cs index 1ca548c2..cd7a6968 100644 --- a/src/TensorFlowNET.Keras/Metrics/RootMeanSquaredError.cs +++ b/src/TensorFlowNET.Keras/Metrics/RootMeanSquaredError.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class RootMeanSquaredError + public class RootMeanSquaredError : Mean { + public RootMeanSquaredError(string name = "root_mean_squared_error", string dtype = null) + : base(name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/SensitivityAtSpecificity.cs b/src/TensorFlowNET.Keras/Metrics/SensitivityAtSpecificity.cs index 93bef74e..72793d79 100644 --- a/src/TensorFlowNET.Keras/Metrics/SensitivityAtSpecificity.cs +++ b/src/TensorFlowNET.Keras/Metrics/SensitivityAtSpecificity.cs @@ -1,10 +1,25 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Text; namespace Tensorflow.Keras.Metrics { - class SensitivityAtSpecificity + public class SensitivityAtSpecificity : SensitivitySpecificityBase { + public SensitivityAtSpecificity(float specificity, int num_thresholds = 200, string name = null, string dtype = null) : base(specificity, num_thresholds, name, dtype) + { + throw new NotImplementedException(); + } + + public override Tensor result() + { + throw new NotImplementedException(); + } + + public override Hashtable get_config() + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/SensitivitySpecificityBase.cs b/src/TensorFlowNET.Keras/Metrics/SensitivitySpecificityBase.cs index 16aec141..7531cdbb 100644 --- a/src/TensorFlowNET.Keras/Metrics/SensitivitySpecificityBase.cs +++ b/src/TensorFlowNET.Keras/Metrics/SensitivitySpecificityBase.cs @@ -4,7 +4,26 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class SensitivitySpecificityBase + public class SensitivitySpecificityBase : Metric { + public SensitivitySpecificityBase(float value, int num_thresholds= 200, string name = null, string dtype = null) : base(name, dtype) + { + throw new NotImplementedException(); + } + + public override Tensor result() + { + throw new NotImplementedException(); + } + + public override void update_state(Args args, KwArgs kwargs) + { + throw new NotImplementedException(); + } + + public override void reset_states() + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs index 32abcd2c..5a57907d 100644 --- a/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs +++ b/src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs @@ -4,7 +4,12 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class SparseCategoricalAccuracy + public class SparseCategoricalAccuracy : MeanMetricWrapper { + public SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", string dtype = null) + : base(Metric.sparse_categorical_accuracy, name, dtype) + { + } + } } diff --git a/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs index 2f1eba09..b02049ad 100644 --- a/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs +++ b/src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs @@ -4,7 +4,17 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class SparseTopKCategoricalAccuracy + public class SparseTopKCategoricalAccuracy : MeanMetricWrapper { + public SparseTopKCategoricalAccuracy(int k = 5, string name = "sparse_top_k_categorical_accuracy", string dtype = null) + : base(Fn, name, dtype) + { + + } + + internal static Tensor Fn(Tensor y_true, Tensor y_pred) + { + return Metric.sparse_top_k_categorical_accuracy(y_true, y_pred); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/SquaredHinge.cs b/src/TensorFlowNET.Keras/Metrics/SquaredHinge.cs index 15041d9b..04a7bef8 100644 --- a/src/TensorFlowNET.Keras/Metrics/SquaredHinge.cs +++ b/src/TensorFlowNET.Keras/Metrics/SquaredHinge.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class SquaredHinge + public class SquaredHinge : MeanMetricWrapper { + public SquaredHinge(string name = "squared_hinge", string dtype = null) + : base(Losses.Loss.squared_hinge, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/Sum.cs b/src/TensorFlowNET.Keras/Metrics/Sum.cs index 10396867..f466a136 100644 --- a/src/TensorFlowNET.Keras/Metrics/Sum.cs +++ b/src/TensorFlowNET.Keras/Metrics/Sum.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class Sum + public class Sum : Reduce { + public Sum(string name, string dtype = null) + : base(Reduction.SUM, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs b/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs index a14e6575..e2c80fad 100644 --- a/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs +++ b/src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs @@ -4,7 +4,16 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class TopKCategoricalAccuracy + public class TopKCategoricalAccuracy : MeanMetricWrapper { + public TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", string dtype = null) + : base(Fn, name, dtype) + { + } + + internal static Tensor Fn(Tensor y_true, Tensor y_pred) + { + return Metric.top_k_categorical_accuracy(y_true, y_pred); + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/TrueNegatives.cs b/src/TensorFlowNET.Keras/Metrics/TrueNegatives.cs index 53e9a894..7e81a2fd 100644 --- a/src/TensorFlowNET.Keras/Metrics/TrueNegatives.cs +++ b/src/TensorFlowNET.Keras/Metrics/TrueNegatives.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class TrueNegatives + public class TrueNegatives : _ConfusionMatrixConditionCount { + public TrueNegatives(float thresholds = 0.5F, string name = null, string dtype = null) + : base(Utils.MetricsUtils.ConfusionMatrix.TRUE_NEGATIVES, thresholds, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/TruePositives.cs b/src/TensorFlowNET.Keras/Metrics/TruePositives.cs index 286c4be7..867049be 100644 --- a/src/TensorFlowNET.Keras/Metrics/TruePositives.cs +++ b/src/TensorFlowNET.Keras/Metrics/TruePositives.cs @@ -4,7 +4,11 @@ using System.Text; namespace Tensorflow.Keras.Metrics { - class TruePositives + public class TruePositives : _ConfusionMatrixConditionCount { + public TruePositives(float thresholds = 0.5F, string name = null, string dtype = null) + : base(Utils.MetricsUtils.ConfusionMatrix.TRUE_POSITIVES, thresholds, name, dtype) + { + } } } diff --git a/src/TensorFlowNET.Keras/Metrics/_ConfusionMatrixConditionCount.cs b/src/TensorFlowNET.Keras/Metrics/_ConfusionMatrixConditionCount.cs index ab9e2c08..3d2be961 100644 --- a/src/TensorFlowNET.Keras/Metrics/_ConfusionMatrixConditionCount.cs +++ b/src/TensorFlowNET.Keras/Metrics/_ConfusionMatrixConditionCount.cs @@ -1,10 +1,37 @@ using System; +using System.Collections; using System.Collections.Generic; using System.Text; +using static Tensorflow.Keras.Utils.MetricsUtils; namespace Tensorflow.Keras.Metrics { - class _ConfusionMatrixConditionCount + public class _ConfusionMatrixConditionCount : Metric { + public _ConfusionMatrixConditionCount(string confusion_matrix_cond, float thresholds= 0.5f, string name= null, string dtype= null) + : base(name, dtype) + { + throw new NotImplementedException(); + } + + public override Tensor result() + { + throw new NotImplementedException(); + } + + public override void update_state(Args args, KwArgs kwargs) + { + throw new NotImplementedException(); + } + + public override void reset_states() + { + throw new NotImplementedException(); + } + + public override Hashtable get_config() + { + throw new NotImplementedException(); + } } } diff --git a/src/TensorFlowNET.Keras/Optimizer/Optimizer.cs b/src/TensorFlowNET.Keras/Optimizer/Optimizer.cs index 14223c5e..ec8bd68a 100644 --- a/src/TensorFlowNET.Keras/Optimizer/Optimizer.cs +++ b/src/TensorFlowNET.Keras/Optimizer/Optimizer.cs @@ -28,7 +28,7 @@ namespace Tensorflow.Keras public static string serialize(Optimizer optimizer) => throw new NotImplementedException(); - public static string deserialize(string config, object custom_objects = null) => throw new NotImplementedException(); + public static Optimizer deserialize(string config, object custom_objects = null) => throw new NotImplementedException(); public static Optimizer get(object identifier) => throw new NotImplementedException();