| @@ -0,0 +1,26 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.KerasApi; | |||||
| namespace Tensorflow.Keras.Losses | |||||
| { | |||||
| public class MeanSquaredError : LossFunctionWrapper, ILossFunc | |||||
| { | |||||
| public MeanSquaredError( | |||||
| string reduction = ReductionV2.AUTO, | |||||
| string name = "mean_squared_error") : | |||||
| base(reduction: reduction, | |||||
| name: name) | |||||
| { | |||||
| } | |||||
| public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) | |||||
| { | |||||
| Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); | |||||
| Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); | |||||
| return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1); | |||||
| } | |||||
| } | |||||
| } | |||||