Browse Source

Mteric and losses skeletonized

tags/v0.20
Deepak Kumar 5 years ago
parent
commit
2706b79edc
36 changed files with 423 additions and 34 deletions
  1. +31
    -0
      src/TensorFlowNET.Keras/Losses/Loss.cs
  2. +32
    -1
      src/TensorFlowNET.Keras/Metrics/AUC.cs
  3. +5
    -1
      src/TensorFlowNET.Keras/Metrics/Accuracy.cs
  4. +10
    -1
      src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs
  5. +5
    -1
      src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs
  6. +5
    -1
      src/TensorFlowNET.Keras/Metrics/CategoricalHinge.cs
  7. +10
    -1
      src/TensorFlowNET.Keras/Metrics/CosineSimilarity.cs
  8. +5
    -1
      src/TensorFlowNET.Keras/Metrics/FalseNegatives.cs
  9. +5
    -1
      src/TensorFlowNET.Keras/Metrics/FalsePositives.cs
  10. +5
    -1
      src/TensorFlowNET.Keras/Metrics/Hinge.cs
  11. +5
    -1
      src/TensorFlowNET.Keras/Metrics/LogCoshError.cs
  12. +6
    -1
      src/TensorFlowNET.Keras/Metrics/Mean.cs
  13. +5
    -1
      src/TensorFlowNET.Keras/Metrics/MeanAbsoluteError.cs
  14. +5
    -1
      src/TensorFlowNET.Keras/Metrics/MeanAbsolutePercentageError.cs
  15. +16
    -1
      src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs
  16. +21
    -1
      src/TensorFlowNET.Keras/Metrics/MeanRelativeError.cs
  17. +5
    -1
      src/TensorFlowNET.Keras/Metrics/MeanSquaredError.cs
  18. +5
    -1
      src/TensorFlowNET.Keras/Metrics/MeanSquaredLogarithmicError.cs
  19. +24
    -0
      src/TensorFlowNET.Keras/Metrics/Metric.cs
  20. +5
    -1
      src/TensorFlowNET.Keras/Metrics/Poisson.cs
  21. +32
    -1
      src/TensorFlowNET.Keras/Metrics/Precision.cs
  22. +16
    -1
      src/TensorFlowNET.Keras/Metrics/PrecisionAtRecall.cs
  23. +32
    -1
      src/TensorFlowNET.Keras/Metrics/Recall.cs
  24. +16
    -1
      src/TensorFlowNET.Keras/Metrics/Reduce.cs
  25. +5
    -1
      src/TensorFlowNET.Keras/Metrics/RootMeanSquaredError.cs
  26. +16
    -1
      src/TensorFlowNET.Keras/Metrics/SensitivityAtSpecificity.cs
  27. +20
    -1
      src/TensorFlowNET.Keras/Metrics/SensitivitySpecificityBase.cs
  28. +6
    -1
      src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs
  29. +11
    -1
      src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs
  30. +5
    -1
      src/TensorFlowNET.Keras/Metrics/SquaredHinge.cs
  31. +5
    -1
      src/TensorFlowNET.Keras/Metrics/Sum.cs
  32. +10
    -1
      src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs
  33. +5
    -1
      src/TensorFlowNET.Keras/Metrics/TrueNegatives.cs
  34. +5
    -1
      src/TensorFlowNET.Keras/Metrics/TruePositives.cs
  35. +28
    -1
      src/TensorFlowNET.Keras/Metrics/_ConfusionMatrixConditionCount.cs
  36. +1
    -1
      src/TensorFlowNET.Keras/Optimizer/Optimizer.cs

+ 31
- 0
src/TensorFlowNET.Keras/Losses/Loss.cs View File

@@ -6,5 +6,36 @@ namespace Tensorflow.Keras.Losses
{ {
public abstract class Loss public abstract class Loss
{ {
public static Tensor mean_squared_error(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor mean_absolute_error(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor mean_absolute_percentage_error(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor mean_squared_logarithmic_error(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor _maybe_convert_labels(Tensor y_true) => throw new NotImplementedException();

public static Tensor squared_hinge(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor hinge(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor categorical_hinge(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor huber_loss(Tensor y_true, Tensor y_pred, float delta = 1) => throw new NotImplementedException();

public static Tensor logcosh(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float label_smoothing = 0) => throw new NotImplementedException();

public static Tensor sparse_categorical_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float axis = -1) => throw new NotImplementedException();

public static Tensor binary_crossentropy(Tensor y_true, Tensor y_pred, bool from_logits = false, float label_smoothing = 0) => throw new NotImplementedException();

public static Tensor kullback_leibler_divergence(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor poisson(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor cosine_similarity(Tensor y_true, Tensor y_pred, int axis = -1) => throw new NotImplementedException();
} }
} }

+ 32
- 1
src/TensorFlowNET.Keras/Metrics/AUC.cs View File

@@ -1,10 +1,41 @@
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class AUC
public class AUC : Metric
{ {
public AUC(int num_thresholds= 200, string curve= "ROC", string summation_method= "interpolation",
string name= null, string dtype= null, float thresholds= 0.5f,
bool multi_label= false, Tensor label_weights= null) : base(name, dtype)
{
throw new NotImplementedException();
}

private void _build(TensorShape shape) => throw new NotImplementedException();

public Tensor interpolate_pr_auc() => throw new NotImplementedException();

public override Tensor result()
{
throw new NotImplementedException();
}

public override void update_state(Args args, KwArgs kwargs)
{
throw new NotImplementedException();
}

public override void reset_states()
{
throw new NotImplementedException();
}

public override Hashtable get_config()
{
throw new NotImplementedException();
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/Accuracy.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class Accuracy
public class Accuracy : MeanMetricWrapper
{ {
public Accuracy(string name = "accuracy", string dtype = null)
: base(Metric.accuracy, name, dtype)
{
}
} }
} }

+ 10
- 1
src/TensorFlowNET.Keras/Metrics/BinaryAccuracy.cs View File

@@ -4,7 +4,16 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class BinaryAccuracy
public class BinaryAccuracy : MeanMetricWrapper
{ {
public BinaryAccuracy(string name = "binary_accuracy", string dtype = null, float threshold = 0.5f)
: base(Fn, name, dtype)
{
}

internal static Tensor Fn(Tensor y_true, Tensor y_pred)
{
return Metric.binary_accuracy(y_true, y_pred);
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/CategoricalAccuracy.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class CategoricalAccuracy
public class CategoricalAccuracy : MeanMetricWrapper
{ {
public CategoricalAccuracy(string name = "categorical_accuracy", string dtype = null)
: base(Metric.categorical_accuracy, name, dtype)
{
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/CategoricalHinge.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class CategoricalHinge
public class CategoricalHinge : MeanMetricWrapper
{ {
public CategoricalHinge(string name = "categorical_hinge", string dtype = null)
: base(Losses.Loss.categorical_hinge, name, dtype)
{
}
} }
} }

+ 10
- 1
src/TensorFlowNET.Keras/Metrics/CosineSimilarity.cs View File

@@ -4,7 +4,16 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class CosineSimilarity
public class CosineSimilarity : MeanMetricWrapper
{ {
public CosineSimilarity(string name = "cosine_similarity", string dtype = null, int axis = -1)
: base(Fn, name, dtype)
{
}

internal static Tensor Fn(Tensor y_true, Tensor y_pred)
{
return Metric.cosine_proximity(y_true, y_pred);
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/FalseNegatives.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class FalseNegatives
public class FalseNegatives : _ConfusionMatrixConditionCount
{ {
public FalseNegatives(float thresholds = 0.5F, string name = null, string dtype = null)
: base(Utils.MetricsUtils.ConfusionMatrix.FALSE_NEGATIVES, thresholds, name, dtype)
{
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/FalsePositives.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class FalsePositives
public class FalsePositives : _ConfusionMatrixConditionCount
{ {
public FalsePositives(float thresholds = 0.5F, string name = null, string dtype = null)
: base(Utils.MetricsUtils.ConfusionMatrix.FALSE_POSITIVES, thresholds, name, dtype)
{
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/Hinge.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class Hinge
public class Hinge : MeanMetricWrapper
{ {
public Hinge(string name = "hinge", string dtype = null)
: base(Losses.Loss.hinge, name, dtype)
{
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/LogCoshError.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class LogCoshError
public class LogCoshError : MeanMetricWrapper
{ {
public LogCoshError(string name = "logcosh", string dtype = null)
: base(Losses.Loss.logcosh, name, dtype)
{
}
} }
} }

+ 6
- 1
src/TensorFlowNET.Keras/Metrics/Mean.cs View File

@@ -4,7 +4,12 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class Mean
public class Mean : Reduce
{ {
public Mean(string name, string dtype = null)
: base(Reduction.MEAN, name, dtype)
{
}

} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/MeanAbsoluteError.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class MeanAbsoluteError
public class MeanAbsoluteError : MeanMetricWrapper
{ {
public MeanAbsoluteError(string name = "mean_absolute_error", string dtype = null)
: base(Losses.Loss.mean_absolute_error, name, dtype)
{
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/MeanAbsolutePercentageError.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class MeanAbsolutePercentageError
public class MeanAbsolutePercentageError : MeanMetricWrapper
{ {
public MeanAbsolutePercentageError(string name = "mean_absolute_percentage_error", string dtype = null)
: base(Losses.Loss.mean_absolute_percentage_error, name, dtype)
{
}
} }
} }

+ 16
- 1
src/TensorFlowNET.Keras/Metrics/MeanMetricWrapper.cs View File

@@ -1,10 +1,25 @@
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class MeanMetricWrapper
public class MeanMetricWrapper : Mean
{ {
public MeanMetricWrapper(Func<Tensor, Tensor, Tensor> fn, string name, string dtype = null) : base(name, dtype)
{
throw new NotImplementedException();
}

public override Tensor result()
{
throw new NotImplementedException();
}

public override Hashtable get_config()
{
throw new NotImplementedException();
}
} }
} }

+ 21
- 1
src/TensorFlowNET.Keras/Metrics/MeanRelativeError.cs View File

@@ -1,10 +1,30 @@
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class MeanRelativeError
public class MeanRelativeError : Metric
{ {
public MeanRelativeError(Tensor normalizer, string name, string dtype) : base(name, dtype)
{
throw new NotImplementedException();
}

public override Tensor result()
{
throw new NotImplementedException();
}

public override void update_state(Args args, KwArgs kwargs)
{
throw new NotImplementedException();
}

public override Hashtable get_config()
{
throw new NotImplementedException();
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/MeanSquaredError.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class MeanSquaredError
public class MeanSquaredError : MeanMetricWrapper
{ {
public MeanSquaredError(string name = "mean_squared_error", string dtype = null)
: base(Losses.Loss.mean_squared_error, name, dtype)
{
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/MeanSquaredLogarithmicError.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class MeanSquaredLogarithmicError
public class MeanSquaredLogarithmicError : MeanMetricWrapper
{ {
public MeanSquaredLogarithmicError(string name = "mean_squared_logarithmic_error", string dtype = null)
: base(Losses.Loss.mean_squared_logarithmic_error, name, dtype)
{
}
} }
} }

+ 24
- 0
src/TensorFlowNET.Keras/Metrics/Metric.cs View File

@@ -35,5 +35,29 @@ namespace Tensorflow.Keras.Metrics
public void add_weight(string name, TensorShape shape= null, VariableAggregation aggregation= VariableAggregation.Sum, public void add_weight(string name, TensorShape shape= null, VariableAggregation aggregation= VariableAggregation.Sum,
VariableSynchronization synchronization = VariableSynchronization.OnRead, Initializers.Initializer initializer= null, VariableSynchronization synchronization = VariableSynchronization.OnRead, Initializers.Initializer initializer= null,
string dtype= null) => throw new NotImplementedException(); string dtype= null) => throw new NotImplementedException();

public static Tensor accuracy(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor binary_accuracy(Tensor y_true, Tensor y_pred, float threshold = 0.5f) => throw new NotImplementedException();

public static Tensor categorical_accuracy(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor sparse_categorical_accuracy(Tensor y_true, Tensor y_pred) => throw new NotImplementedException();

public static Tensor top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5) => throw new NotImplementedException();

public static Tensor sparse_top_k_categorical_accuracy(Tensor y_true, Tensor y_pred, int k = 5) => throw new NotImplementedException();

public static Tensor cosine_proximity(Tensor y_true, Tensor y_pred, int axis = -1) => throw new NotImplementedException();

public static Metric clone_metric(Metric metric) => throw new NotImplementedException();

public static Metric[] clone_metrics(Metric[] metric) => throw new NotImplementedException();

public static string serialize(Metric metric) => throw new NotImplementedException();

public static Metric deserialize(string config, object custom_objects = null) => throw new NotImplementedException();

public static Metric get(object identifier) => throw new NotImplementedException();
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/Poisson.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class Poisson
public class Poisson : MeanMetricWrapper
{ {
public Poisson(string name = "logcosh", string dtype = null)
: base(Losses.Loss.logcosh, name, dtype)
{
}
} }
} }

+ 32
- 1
src/TensorFlowNET.Keras/Metrics/Precision.cs View File

@@ -1,10 +1,41 @@
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class Precision
public class Precision : Metric
{ {
public Precision(float? thresholds = null, int? top_k = null, int? class_id = null, string name = null, string dtype = null) : base(name, dtype)
{
throw new NotImplementedException();
}

public Precision(float[] thresholds = null, int? top_k = null, int? class_id = null, string name = null, string dtype = null) : base(name, dtype)
{
throw new NotImplementedException();
}

public override Tensor result()
{
throw new NotImplementedException();
}

public override void update_state(Args args, KwArgs kwargs)
{
throw new NotImplementedException();
}

public override void reset_states()
{
throw new NotImplementedException();
}

public override Hashtable get_config()
{
throw new NotImplementedException();
}

} }
} }

+ 16
- 1
src/TensorFlowNET.Keras/Metrics/PrecisionAtRecall.cs View File

@@ -1,10 +1,25 @@
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class PrecisionAtRecall
public class PrecisionAtRecall : SensitivitySpecificityBase
{ {
public PrecisionAtRecall(float recall, int num_thresholds = 200, string name = null, string dtype = null) : base(recall, num_thresholds, name, dtype)
{
throw new NotImplementedException();
}

public override Tensor result()
{
throw new NotImplementedException();
}

public override Hashtable get_config()
{
throw new NotImplementedException();
}
} }
} }

+ 32
- 1
src/TensorFlowNET.Keras/Metrics/Recall.cs View File

@@ -1,10 +1,41 @@
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class Recall
public class Recall : Metric
{ {
public Recall(float? thresholds = null, int? top_k = null, int? class_id = null, string name = null, string dtype = null) : base(name, dtype)
{
throw new NotImplementedException();
}

public Recall(float[] thresholds = null, int? top_k = null, int? class_id = null, string name = null, string dtype = null) : base(name, dtype)
{
throw new NotImplementedException();
}

public override Tensor result()
{
throw new NotImplementedException();
}

public override void update_state(Args args, KwArgs kwargs)
{
throw new NotImplementedException();
}

public override void reset_states()
{
throw new NotImplementedException();
}

public override Hashtable get_config()
{
throw new NotImplementedException();
}

} }
} }

+ 16
- 1
src/TensorFlowNET.Keras/Metrics/Reduce.cs View File

@@ -4,7 +4,22 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class Reduce
public class Reduce : Metric
{ {
public Reduce(string reduction, string name, string dtype= null)
: base(name, dtype)
{
throw new NotImplementedException();
}

public override Tensor result()
{
throw new NotImplementedException();
}

public override void update_state(Args args, KwArgs kwargs)
{
throw new NotImplementedException();
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/RootMeanSquaredError.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class RootMeanSquaredError
public class RootMeanSquaredError : Mean
{ {
public RootMeanSquaredError(string name = "root_mean_squared_error", string dtype = null)
: base(name, dtype)
{
}
} }
} }

+ 16
- 1
src/TensorFlowNET.Keras/Metrics/SensitivityAtSpecificity.cs View File

@@ -1,10 +1,25 @@
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class SensitivityAtSpecificity
public class SensitivityAtSpecificity : SensitivitySpecificityBase
{ {
public SensitivityAtSpecificity(float specificity, int num_thresholds = 200, string name = null, string dtype = null) : base(specificity, num_thresholds, name, dtype)
{
throw new NotImplementedException();
}

public override Tensor result()
{
throw new NotImplementedException();
}

public override Hashtable get_config()
{
throw new NotImplementedException();
}
} }
} }

+ 20
- 1
src/TensorFlowNET.Keras/Metrics/SensitivitySpecificityBase.cs View File

@@ -4,7 +4,26 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class SensitivitySpecificityBase
public class SensitivitySpecificityBase : Metric
{ {
public SensitivitySpecificityBase(float value, int num_thresholds= 200, string name = null, string dtype = null) : base(name, dtype)
{
throw new NotImplementedException();
}

public override Tensor result()
{
throw new NotImplementedException();
}

public override void update_state(Args args, KwArgs kwargs)
{
throw new NotImplementedException();
}

public override void reset_states()
{
throw new NotImplementedException();
}
} }
} }

+ 6
- 1
src/TensorFlowNET.Keras/Metrics/SparseCategoricalAccuracy.cs View File

@@ -4,7 +4,12 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class SparseCategoricalAccuracy
public class SparseCategoricalAccuracy : MeanMetricWrapper
{ {
public SparseCategoricalAccuracy(string name = "sparse_categorical_accuracy", string dtype = null)
: base(Metric.sparse_categorical_accuracy, name, dtype)
{
}

} }
} }

+ 11
- 1
src/TensorFlowNET.Keras/Metrics/SparseTopKCategoricalAccuracy.cs View File

@@ -4,7 +4,17 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class SparseTopKCategoricalAccuracy
public class SparseTopKCategoricalAccuracy : MeanMetricWrapper
{ {
public SparseTopKCategoricalAccuracy(int k = 5, string name = "sparse_top_k_categorical_accuracy", string dtype = null)
: base(Fn, name, dtype)
{
}

internal static Tensor Fn(Tensor y_true, Tensor y_pred)
{
return Metric.sparse_top_k_categorical_accuracy(y_true, y_pred);
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/SquaredHinge.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class SquaredHinge
public class SquaredHinge : MeanMetricWrapper
{ {
public SquaredHinge(string name = "squared_hinge", string dtype = null)
: base(Losses.Loss.squared_hinge, name, dtype)
{
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/Sum.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class Sum
public class Sum : Reduce
{ {
public Sum(string name, string dtype = null)
: base(Reduction.SUM, name, dtype)
{
}
} }
} }

+ 10
- 1
src/TensorFlowNET.Keras/Metrics/TopKCategoricalAccuracy.cs View File

@@ -4,7 +4,16 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class TopKCategoricalAccuracy
public class TopKCategoricalAccuracy : MeanMetricWrapper
{ {
public TopKCategoricalAccuracy(int k = 5, string name = "top_k_categorical_accuracy", string dtype = null)
: base(Fn, name, dtype)
{
}

internal static Tensor Fn(Tensor y_true, Tensor y_pred)
{
return Metric.top_k_categorical_accuracy(y_true, y_pred);
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/TrueNegatives.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class TrueNegatives
public class TrueNegatives : _ConfusionMatrixConditionCount
{ {
public TrueNegatives(float thresholds = 0.5F, string name = null, string dtype = null)
: base(Utils.MetricsUtils.ConfusionMatrix.TRUE_NEGATIVES, thresholds, name, dtype)
{
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/TruePositives.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class TruePositives
public class TruePositives : _ConfusionMatrixConditionCount
{ {
public TruePositives(float thresholds = 0.5F, string name = null, string dtype = null)
: base(Utils.MetricsUtils.ConfusionMatrix.TRUE_POSITIVES, thresholds, name, dtype)
{
}
} }
} }

+ 28
- 1
src/TensorFlowNET.Keras/Metrics/_ConfusionMatrixConditionCount.cs View File

@@ -1,10 +1,37 @@
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using static Tensorflow.Keras.Utils.MetricsUtils;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class _ConfusionMatrixConditionCount
public class _ConfusionMatrixConditionCount : Metric
{ {
public _ConfusionMatrixConditionCount(string confusion_matrix_cond, float thresholds= 0.5f, string name= null, string dtype= null)
: base(name, dtype)
{
throw new NotImplementedException();
}

public override Tensor result()
{
throw new NotImplementedException();
}

public override void update_state(Args args, KwArgs kwargs)
{
throw new NotImplementedException();
}

public override void reset_states()
{
throw new NotImplementedException();
}

public override Hashtable get_config()
{
throw new NotImplementedException();
}
} }
} }

+ 1
- 1
src/TensorFlowNET.Keras/Optimizer/Optimizer.cs View File

@@ -28,7 +28,7 @@ namespace Tensorflow.Keras


public static string serialize(Optimizer optimizer) => throw new NotImplementedException(); public static string serialize(Optimizer optimizer) => throw new NotImplementedException();


public static string deserialize(string config, object custom_objects = null) => throw new NotImplementedException();
public static Optimizer deserialize(string config, object custom_objects = null) => throw new NotImplementedException();


public static Optimizer get(object identifier) => throw new NotImplementedException(); public static Optimizer get(object identifier) => throw new NotImplementedException();




Loading…
Cancel
Save