| @@ -0,0 +1,65 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using NumSharp; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Keras.Losses; | |||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.KerasApi; | |||||
| namespace TensorFlowNET.UnitTest.Keras | |||||
| { | |||||
| [TestClass] | |||||
| public class MeanSquaredErrorTest | |||||
| { | |||||
| //https://keras.io/api/losses/regression_losses/#meansquarederror-class | |||||
| private NDArray y_true = new double[,] { { 0.0, 1.0 }, { 0.0, 0.0 } }; | |||||
| private NDArray y_pred = new double[,] { { 1.0, 1.0 }, { 1.0, 0.0 } }; | |||||
| [TestMethod] | |||||
| public void Mse_Double() | |||||
| { | |||||
| var mse = keras.losses.MeanSquaredError(); | |||||
| var call = mse.Call(y_true, y_pred); | |||||
| Assert.AreEqual((NDArray)0.5, call.numpy()) ; | |||||
| } | |||||
| [TestMethod] | |||||
| public void Mse_Float() | |||||
| { | |||||
| NDArray y_true_float = new float[,] { { 0.0f, 1.0f }, { 0.0f, 0.0f } }; | |||||
| NDArray y_pred_float = new float[,] { { 1.0f, 1.0f }, { 1.0f, 0.0f } }; | |||||
| var mse = keras.losses.MeanSquaredError(); | |||||
| var call = mse.Call(y_true_float, y_pred_float); | |||||
| Assert.AreEqual((NDArray)0.5, call.numpy()); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Mse_Sample_Weight() | |||||
| { | |||||
| var mse = keras.losses.MeanSquaredError(); | |||||
| var call = mse.Call(y_true, y_pred, sample_weight: (NDArray)new double[] { 0.7, 0.3 }); | |||||
| Assert.AreEqual((NDArray)0.25, call.numpy()); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Mse_Reduction_SUM() | |||||
| { | |||||
| var mse = keras.losses.MeanSquaredError(reduction: Reduction.SUM); | |||||
| var call = mse.Call(y_true, y_pred); | |||||
| Assert.AreEqual((NDArray)1.0, call.numpy()); | |||||
| } | |||||
| [TestMethod] | |||||
| public void Mse_Reduction_NONE() | |||||
| { | |||||
| var mse = keras.losses.MeanSquaredError(reduction: Reduction.NONE); | |||||
| var call = mse.Call(y_true, y_pred); | |||||
| Assert.AreEqual((NDArray)new double[] { 0.5, 0.5 }, call.numpy()); | |||||
| } | |||||
| } | |||||
| } | |||||