From cb7e0170eb49ac926364d01c78f5f18d84dcd4a4 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sun, 17 Oct 2021 21:43:45 -0500 Subject: [PATCH] mean_absolute_percentage_error --- src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs | 4 ++++ src/TensorFlowNET.Keras/Engine/MetricsContainer.cs | 4 ++-- src/TensorFlowNET.Keras/Engine/Model.Compile.cs | 3 ++- src/TensorFlowNET.Keras/Metrics/MetricsApi.cs | 7 +++++++ 4 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs index 879a38e4..7876a990 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.Index.cs @@ -19,6 +19,7 @@ using System; using System.Collections.Generic; using System.Text; using static Tensorflow.Binding; +using System.Linq; namespace Tensorflow { @@ -62,5 +63,8 @@ namespace Tensorflow }); } } + + public Tensor this[params string[] slices] + => this[slices.Select(x => new Slice(x)).ToArray()]; } } diff --git a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs index 037703c8..790221f8 100644 --- a/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs +++ b/src/TensorFlowNET.Keras/Engine/MetricsContainer.cs @@ -77,9 +77,9 @@ namespace Tensorflow.Keras.Engine metric_obj = keras.metrics.categorical_accuracy; } else if(metric == "mean_absolute_error" || metric == "mae") - { metric_obj = keras.metrics.mean_absolute_error; - } + else if (metric == "mean_absolute_percentage_error" || metric == "mape") + metric_obj = keras.metrics.mean_absolute_percentage_error; else throw new NotImplementedException(""); diff --git a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs index 71bd2f38..7b051f1d 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Compile.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Compile.cs @@ -42,9 +42,10 @@ namespace Tensorflow.Keras.Engine _ => throw new NotImplementedException("") }; - var _loss = loss switch + ILossFunc _loss = loss switch { "mse" => new MeanSquaredError(), + "mae" => new MeanAbsoluteError(), _ => throw new NotImplementedException("") }; diff --git a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs index c8d54fc9..3d614e02 100644 --- a/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs +++ b/src/TensorFlowNET.Keras/Metrics/MetricsApi.cs @@ -46,5 +46,12 @@ namespace Tensorflow.Keras.Metrics y_true = math_ops.cast(y_true, y_pred.dtype); return keras.backend.mean(math_ops.abs(y_pred - y_true), axis: -1); } + + public Tensor mean_absolute_percentage_error(Tensor y_true, Tensor y_pred) + { + y_true = math_ops.cast(y_true, y_pred.dtype); + var diff = (y_true - y_pred) / math_ops.maximum(math_ops.abs(y_true), keras.backend.epsilon()); + return 100f * keras.backend.mean(math_ops.abs(diff), axis: -1); + } } }