| @@ -6,5 +6,36 @@ namespace Tensorflow.Keras.Losses | |||||
| { | { | ||||
| public abstract class Loss | public abstract class Loss | ||||
| { | { | ||||
| public static Tensor mean_squared_error(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor mean_absolute_error(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor mean_absolute_percentage_error(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor mean_squared_logarithmic_error(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor _maybe_convert_labels(Tensor y_true) => throw new NotImplementedException(); | |||||
| public static Tensor squared_hinge(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor hinge(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor categorical_hinge(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor huber_loss(Tensor y_true, Tensor y_pred, float delta = 1) => throw new NotImplementedException(); | |||||
| public static Tensor logcosh(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float label_smoothing = 0) => throw new NotImplementedException(); | |||||
| public static Tensor sparse_categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float axis = -1) => throw new NotImplementedException(); | |||||
| public static Tensor binary_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float label_smoothing = 0) => throw new NotImplementedException(); | |||||
| public static Tensor kullback_leibler_divergence(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor poisson(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor cosine_similarity(Tensor y_true, Tensor y_pred, int axis = -1) => throw new NotImplementedException(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,41 @@ | |||||
| using System; | using System; | ||||
| using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class AUC | |||||
| public class AUC : Metric | |||||
| { | { | ||||
| public AUC(int num_thresholds= 200, string curve= "ROC", string summation_method= "interpolation", | |||||
| string name= null, string dtype= null, float thresholds= 0.5f, | |||||
| bool multi_label= false, Tensor label_weights= null) : base(name, dtype) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| private void _build(TensorShape shape) => throw new NotImplementedException(); | |||||
| public Tensor interpolate_pr_auc() => throw new NotImplementedException(); | |||||
| public override Tensor result() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void update_state(Args args, KwArgs kwargs) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void reset_states() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Hashtable get_config() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class Accuracy | |||||
| public class Accuracy : MeanMetricWrapper | |||||
| { | { | ||||
| public Accuracy(string name = "accuracy", string dtype = null) | |||||
| : base(Metric.accuracy, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,16 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class BinaryAccuracy | |||||
| public class BinaryAccuracy : MeanMetricWrapper | |||||
| { | { | ||||
| public BinaryAccuracy(string name = "binary_accuracy", string dtype = null, float threshold = 0.5f) | |||||
| : base(Fn, name, dtype) | |||||
| { | |||||
| } | |||||
| internal static Tensor Fn(Tensor y_true, Tensor y_pred) | |||||
| { | |||||
| return Metric.binary_accuracy(y_true, y_pred); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class CategoricalAccuracy | |||||
| public class CategoricalAccuracy : MeanMetricWrapper | |||||
| { | { | ||||
| public CategoricalAccuracy(string name = "categorical_accuracy", string dtype = null) | |||||
| : base(Metric.categorical_accuracy, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class CategoricalHinge | |||||
| public class CategoricalHinge : MeanMetricWrapper | |||||
| { | { | ||||
| public CategoricalHinge(string name = "categorical_hinge", string dtype = null) | |||||
| : base(Losses.Loss.categorical_hinge, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,16 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class CosineSimilarity | |||||
| public class CosineSimilarity : MeanMetricWrapper | |||||
| { | { | ||||
| public CosineSimilarity(string name = "cosine_similarity", string dtype = null, int axis = -1) | |||||
| : base(Fn, name, dtype) | |||||
| { | |||||
| } | |||||
| internal static Tensor Fn(Tensor y_true, Tensor y_pred) | |||||
| { | |||||
| return Metric.cosine_proximity(y_true, y_pred); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class FalseNegatives | |||||
| public class FalseNegatives : _ConfusionMatrixConditionCount | |||||
| { | { | ||||
| public FalseNegatives(float thresholds = 0.5F, string name = null, string dtype = null) | |||||
| : base(Utils.MetricsUtils.ConfusionMatrix.FALSE_NEGATIVES, thresholds, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class FalsePositives | |||||
| public class FalsePositives : _ConfusionMatrixConditionCount | |||||
| { | { | ||||
| public FalsePositives(float thresholds = 0.5F, string name = null, string dtype = null) | |||||
| : base(Utils.MetricsUtils.ConfusionMatrix.FALSE_POSITIVES, thresholds, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class Hinge | |||||
| public class Hinge : MeanMetricWrapper | |||||
| { | { | ||||
| public Hinge(string name = "hinge", string dtype = null) | |||||
| : base(Losses.Loss.hinge, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class LogCoshError | |||||
| public class LogCoshError : MeanMetricWrapper | |||||
| { | { | ||||
| public LogCoshError(string name = "logcosh", string dtype = null) | |||||
| : base(Losses.Loss.logcosh, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,12 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class Mean | |||||
| public class Mean : Reduce | |||||
| { | { | ||||
| public Mean(string name, string dtype = null) | |||||
| : base(Reduction.MEAN, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class MeanAbsoluteError | |||||
| public class MeanAbsoluteError : MeanMetricWrapper | |||||
| { | { | ||||
| public MeanAbsoluteError(string name = "mean_absolute_error", string dtype = null) | |||||
| : base(Losses.Loss.mean_absolute_error, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class MeanAbsolutePercentageError | |||||
| public class MeanAbsolutePercentageError : MeanMetricWrapper | |||||
| { | { | ||||
| public MeanAbsolutePercentageError(string name = "mean_absolute_percentage_error", string dtype = null) | |||||
| : base(Losses.Loss.mean_absolute_percentage_error, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,25 @@ | |||||
| using System; | using System; | ||||
| using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class MeanMetricWrapper | |||||
| public class MeanMetricWrapper : Mean | |||||
| { | { | ||||
| public MeanMetricWrapper(Func<Tensor, Tensor, Tensor> fn, string name, string dtype = null) : base(name, dtype) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Tensor result() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Hashtable get_config() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,30 @@ | |||||
| using System; | using System; | ||||
| using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class MeanRelativeError | |||||
| public class MeanRelativeError : Metric | |||||
| { | { | ||||
| public MeanRelativeError(Tensor normalizer, string name, string dtype) : base(name, dtype) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Tensor result() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void update_state(Args args, KwArgs kwargs) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Hashtable get_config() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class MeanSquaredError | |||||
| public class MeanSquaredError : MeanMetricWrapper | |||||
| { | { | ||||
| public MeanSquaredError(string name = "mean_squared_error", string dtype = null) | |||||
| : base(Losses.Loss.mean_squared_error, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class MeanSquaredLogarithmicError | |||||
| public class MeanSquaredLogarithmicError : MeanMetricWrapper | |||||
| { | { | ||||
| public MeanSquaredLogarithmicError(string name = "mean_squared_logarithmic_error", string dtype = null) | |||||
| : base(Losses.Loss.mean_squared_logarithmic_error, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -35,5 +35,29 @@ namespace Tensorflow.Keras.Metrics | |||||
| public void add_weight(string name, TensorShape shape= null, VariableAggregation aggregation= VariableAggregation.Sum, | public void add_weight(string name, TensorShape shape= null, VariableAggregation aggregation= VariableAggregation.Sum, | ||||
| VariableSynchronization synchronization = VariableSynchronization.OnRead, Initializers.Initializer initializer= null, | VariableSynchronization synchronization = VariableSynchronization.OnRead, Initializers.Initializer initializer= null, | ||||
| string dtype= null) => throw new NotImplementedException(); | string dtype= null) => throw new NotImplementedException(); | ||||
| public static Tensor accuracy(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor binary_accuracy(Tensor y_true, Tensor y_pred, float threshold = 0.5f) => throw new NotImplementedException(); | |||||
| public static Tensor categorical_accuracy(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred) => throw new NotImplementedException(); | |||||
| public static Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5) => throw new NotImplementedException(); | |||||
| public static Tensor sparse_top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5) => throw new NotImplementedException(); | |||||
| public static Tensor cosine_proximity(Tensor y_true, Tensor y_pred, int axis = -1) => throw new NotImplementedException(); | |||||
| public static Metric clone_metric(Metric metric) => throw new NotImplementedException(); | |||||
| public static Metric[] clone_metrics(Metric[] metric) => throw new NotImplementedException(); | |||||
| public static string serialize(Metric metric) => throw new NotImplementedException(); | |||||
| public static Metric deserialize(string config, object custom_objects = null) => throw new NotImplementedException(); | |||||
| public static Metric get(object identifier) => throw new NotImplementedException(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class Poisson | |||||
| public class Poisson : MeanMetricWrapper | |||||
| { | { | ||||
| public Poisson(string name = "logcosh", string dtype = null) | |||||
| : base(Losses.Loss.logcosh, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,41 @@ | |||||
| using System; | using System; | ||||
| using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class Precision | |||||
| public class Precision : Metric | |||||
| { | { | ||||
| public Precision(float? thresholds = null, int? top_k = null, int? class_id = null, string name = null, string dtype = null) : base(name, dtype) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public Precision(float[] thresholds = null, int? top_k = null, int? class_id = null, string name = null, string dtype = null) : base(name, dtype) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Tensor result() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void update_state(Args args, KwArgs kwargs) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void reset_states() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Hashtable get_config() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,25 @@ | |||||
| using System; | using System; | ||||
| using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class PrecisionAtRecall | |||||
| public class PrecisionAtRecall : SensitivitySpecificityBase | |||||
| { | { | ||||
| public PrecisionAtRecall(float recall, int num_thresholds = 200, string name = null, string dtype = null) : base(recall, num_thresholds, name, dtype) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Tensor result() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Hashtable get_config() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,41 @@ | |||||
| using System; | using System; | ||||
| using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class Recall | |||||
| public class Recall : Metric | |||||
| { | { | ||||
| public Recall(float? thresholds = null, int? top_k = null, int? class_id = null, string name = null, string dtype = null) : base(name, dtype) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public Recall(float[] thresholds = null, int? top_k = null, int? class_id = null, string name = null, string dtype = null) : base(name, dtype) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Tensor result() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void update_state(Args args, KwArgs kwargs) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void reset_states() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Hashtable get_config() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,22 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class Reduce | |||||
| public class Reduce : Metric | |||||
| { | { | ||||
| public Reduce(string reduction, string name, string dtype= null) | |||||
| : base(name, dtype) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Tensor result() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void update_state(Args args, KwArgs kwargs) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class RootMeanSquaredError | |||||
| public class RootMeanSquaredError : Mean | |||||
| { | { | ||||
| public RootMeanSquaredError(string name = "root_mean_squared_error", string dtype = null) | |||||
| : base(name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,25 @@ | |||||
| using System; | using System; | ||||
| using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class SensitivityAtSpecificity | |||||
| public class SensitivityAtSpecificity : SensitivitySpecificityBase | |||||
| { | { | ||||
| public SensitivityAtSpecificity(float specificity, int num_thresholds = 200, string name = null, string dtype = null) : base(specificity, num_thresholds, name, dtype) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Tensor result() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Hashtable get_config() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,26 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class SensitivitySpecificityBase | |||||
| public class SensitivitySpecificityBase : Metric | |||||
| { | { | ||||
| public SensitivitySpecificityBase(float value, int num_thresholds= 200, string name = null, string dtype = null) : base(name, dtype) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Tensor result() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void update_state(Args args, KwArgs kwargs) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void reset_states() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,12 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class SparseCategoricalAccuracy | |||||
| public class SparseCategoricalAccuracy : MeanMetricWrapper | |||||
| { | { | ||||
| public SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", string dtype = null) | |||||
| : base(Metric.sparse_categorical_accuracy, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,17 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class SparseTopKCategoricalAccuracy | |||||
| public class SparseTopKCategoricalAccuracy : MeanMetricWrapper | |||||
| { | { | ||||
| public SparseTopKCategoricalAccuracy(int k = 5, string name = "sparse_top_k_categorical_accuracy", string dtype = null) | |||||
| : base(Fn, name, dtype) | |||||
| { | |||||
| } | |||||
| internal static Tensor Fn(Tensor y_true, Tensor y_pred) | |||||
| { | |||||
| return Metric.sparse_top_k_categorical_accuracy(y_true, y_pred); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class SquaredHinge | |||||
| public class SquaredHinge : MeanMetricWrapper | |||||
| { | { | ||||
| public SquaredHinge(string name = "squared_hinge", string dtype = null) | |||||
| : base(Losses.Loss.squared_hinge, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class Sum | |||||
| public class Sum : Reduce | |||||
| { | { | ||||
| public Sum(string name, string dtype = null) | |||||
| : base(Reduction.SUM, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,16 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class TopKCategoricalAccuracy | |||||
| public class TopKCategoricalAccuracy : MeanMetricWrapper | |||||
| { | { | ||||
| public TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", string dtype = null) | |||||
| : base(Fn, name, dtype) | |||||
| { | |||||
| } | |||||
| internal static Tensor Fn(Tensor y_true, Tensor y_pred) | |||||
| { | |||||
| return Metric.top_k_categorical_accuracy(y_true, y_pred); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class TrueNegatives | |||||
| public class TrueNegatives : _ConfusionMatrixConditionCount | |||||
| { | { | ||||
| public TrueNegatives(float thresholds = 0.5F, string name = null, string dtype = null) | |||||
| : base(Utils.MetricsUtils.ConfusionMatrix.TRUE_NEGATIVES, thresholds, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -4,7 +4,11 @@ using System.Text; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class TruePositives | |||||
| public class TruePositives : _ConfusionMatrixConditionCount | |||||
| { | { | ||||
| public TruePositives(float thresholds = 0.5F, string name = null, string dtype = null) | |||||
| : base(Utils.MetricsUtils.ConfusionMatrix.TRUE_POSITIVES, thresholds, name, dtype) | |||||
| { | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,10 +1,37 @@ | |||||
| using System; | using System; | ||||
| using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| using static Tensorflow.Keras.Utils.MetricsUtils; | |||||
| namespace Tensorflow.Keras.Metrics | namespace Tensorflow.Keras.Metrics | ||||
| { | { | ||||
| class _ConfusionMatrixConditionCount | |||||
| public class _ConfusionMatrixConditionCount : Metric | |||||
| { | { | ||||
| public _ConfusionMatrixConditionCount(string confusion_matrix_cond, float thresholds= 0.5f, string name= null, string dtype= null) | |||||
| : base(name, dtype) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Tensor result() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void update_state(Args args, KwArgs kwargs) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override void reset_states() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| public override Hashtable get_config() | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -28,7 +28,7 @@ namespace Tensorflow.Keras | |||||
| public static string serialize(Optimizer optimizer) => throw new NotImplementedException(); | public static string serialize(Optimizer optimizer) => throw new NotImplementedException(); | ||||
| public static string deserialize(string config, object custom_objects = null) => throw new NotImplementedException(); | |||||
| public static Optimizer deserialize(string config, object custom_objects = null) => throw new NotImplementedException(); | |||||
| public static Optimizer get(object identifier) => throw new NotImplementedException(); | public static Optimizer get(object identifier) => throw new NotImplementedException(); | ||||