Browse Source

Metrics methods skeletonizing completed.

tags/v0.20
Deepak Battini 5 years ago
parent
commit
7c587e40fe
8 changed files with 118 additions and 8 deletions
  1. +10
    -1
      src/TensorFlowNET.Keras/Metrics/BinaryCrossentropy.cs
  2. +10
    -1
      src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs
  3. +5
    -1
      src/TensorFlowNET.Keras/Metrics/KLDivergence.cs
  4. +25
    -1
      src/TensorFlowNET.Keras/Metrics/MeanIoU.cs
  5. +38
    -1
      src/TensorFlowNET.Keras/Metrics/MeanTensor.cs
  6. +10
    -1
      src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs
  7. +4
    -1
      src/TensorFlowNET.Keras/Metrics/SumOverBatchSize.cs
  8. +16
    -1
      src/TensorFlowNET.Keras/Metrics/SumOverBatchSizeMetricWrapper.cs

+ 10
- 1
src/TensorFlowNET.Keras/Metrics/BinaryCrossentropy.cs View File

@@ -4,7 +4,16 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class BinaryCrossentropy
public class BinaryCrossentropy : MeanMetricWrapper
{ {
public BinaryCrossentropy(string name = "binary_crossentropy", string dtype = null, bool from_logits = false, float label_smoothing = 0)
: base(Fn, name, dtype)
{
}

internal static Tensor Fn(Tensor y_true, Tensor y_pred)
{
return Losses.Loss.binary_crossentropy(y_true, y_pred);
}
} }
} }

+ 10
- 1
src/TensorFlowNET.Keras/Metrics/CategoricalCrossentropy.cs View File

@@ -4,7 +4,16 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class CategoricalCrossentropy
public class CategoricalCrossentropy : MeanMetricWrapper
{ {
public CategoricalCrossentropy(string name = "categorical_crossentropy", string dtype = null, bool from_logits = false, float label_smoothing = 0)
: base(Fn, name, dtype)
{
}

internal static Tensor Fn(Tensor y_true, Tensor y_pred)
{
return Losses.Loss.categorical_crossentropy(y_true, y_pred);
}
} }
} }

+ 5
- 1
src/TensorFlowNET.Keras/Metrics/KLDivergence.cs View File

@@ -4,7 +4,11 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class KLDivergence
public class KLDivergence : MeanMetricWrapper
{ {
public KLDivergence(string name = "kullback_leibler_divergence", string dtype = null)
: base(Losses.Loss.logcosh, name, dtype)
{
}
} }
} }

+ 25
- 1
src/TensorFlowNET.Keras/Metrics/MeanIoU.cs View File

@@ -1,10 +1,34 @@
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class MeanIoU
public class MeanIoU : Metric
{ {
public MeanIoU(int num_classes, string name, string dtype) : base(name, dtype)
{
}

public override void reset_states()
{
throw new NotImplementedException();
}

public override Hashtable get_config()
{
throw new NotImplementedException();
}

public override Tensor result()
{
throw new NotImplementedException();
}

public override void update_state(Args args, KwArgs kwargs)
{
throw new NotImplementedException();
}
} }
} }

+ 38
- 1
src/TensorFlowNET.Keras/Metrics/MeanTensor.cs View File

@@ -4,7 +4,44 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class MeanTensor
public class MeanTensor : Metric
{ {
public int total
{
get
{
throw new NotImplementedException();
}
}

public int count
{
get
{
throw new NotImplementedException();
}
}

public MeanTensor(int num_classes, string name = "mean_tensor", string dtype) : base(name, dtype)
{
}


private void _build(TensorShape shape) => throw new NotImplementedException();

public override void reset_states()
{
throw new NotImplementedException();
}

public override Tensor result()
{
throw new NotImplementedException();
}

public override void update_state(Args args, KwArgs kwargs)
{
throw new NotImplementedException();
}
} }
} }

+ 10
- 1
src/TensorFlowNET.Keras/Metrics/SparseCategoricalCrossentropy.cs View File

@@ -4,7 +4,16 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class SparseCategoricalCrossentropy
public class SparseCategoricalCrossentropy : MeanMetricWrapper
{ {
public SparseCategoricalCrossentropy(string name = "sparse_categorical_crossentropy", string dtype = null, bool from_logits = false, int axis = -1)
: base(Fn, name, dtype)
{
}

internal static Tensor Fn(Tensor y_true, Tensor y_pred)
{
return Losses.Loss.sparse_categorical_crossentropy(y_true, y_pred);
}
} }
} }

+ 4
- 1
src/TensorFlowNET.Keras/Metrics/SumOverBatchSize.cs View File

@@ -4,7 +4,10 @@ using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class SumOverBatchSize
public class SumOverBatchSize : Reduce
{ {
public SumOverBatchSize(string name = "sum_over_batch_size", string dtype = null) : base(Reduction.SUM_OVER_BATCH_SIZE, name, dtype)
{
}
} }
} }

+ 16
- 1
src/TensorFlowNET.Keras/Metrics/SumOverBatchSizeMetricWrapper.cs View File

@@ -1,10 +1,25 @@
using System; using System;
using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;


namespace Tensorflow.Keras.Metrics namespace Tensorflow.Keras.Metrics
{ {
class SumOverBatchSizeMetricWrapper
public class SumOverBatchSizeMetricWrapper : SumOverBatchSize
{ {
public SumOverBatchSizeMetricWrapper(Func<Tensor, Tensor, Tensor> fn, string name, string dtype = null)
{
throw new NotImplementedException();
}

public override void update_state(Args args, KwArgs kwargs)
{
throw new NotImplementedException();
}

public override Hashtable get_config()
{
throw new NotImplementedException();
}
} }
} }

Loading…
Cancel
Save