diff --git a/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs new file mode 100644 index 00000000..884808f4 --- /dev/null +++ b/src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; + +namespace Tensorflow.Keras.Losses +{ + public class MeanSquaredError : LossFunctionWrapper, ILossFunc + { + public MeanSquaredError( + string reduction = ReductionV2.AUTO, + string name = "mean_squared_error") : + base(reduction: reduction, + name: name) + { + } + + public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) + { + Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); + Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); + return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1); + } + } +}