|
|
|
@@ -0,0 +1,26 @@ |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.Text; |
|
|
|
using static Tensorflow.Binding; |
|
|
|
using static Tensorflow.KerasApi; |
|
|
|
|
|
|
|
namespace Tensorflow.Keras.Losses |
|
|
|
{ |
|
|
|
public class MeanSquaredError : LossFunctionWrapper, ILossFunc |
|
|
|
{ |
|
|
|
public MeanSquaredError( |
|
|
|
string reduction = ReductionV2.AUTO, |
|
|
|
string name = "mean_squared_error") : |
|
|
|
base(reduction: reduction, |
|
|
|
name: name) |
|
|
|
{ |
|
|
|
} |
|
|
|
|
|
|
|
public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) |
|
|
|
{
|
|
|
|
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
|
|
|
|
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
|
|
|
|
return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1); |
|
|
|
} |
|
|
|
} |
|
|
|
} |