Browse Source

MeanSquaredError.cs

pull/678/head
dataangel GitHub 5 years ago
parent
commit
bbefacb24b
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 26 additions and 0 deletions
  1. +26
    -0
      src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs

+ 26
- 0
src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs View File

@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Losses
{
public class MeanSquaredError : LossFunctionWrapper, ILossFunc
{
public MeanSquaredError(
string reduction = ReductionV2.AUTO,
string name = "mean_squared_error") :
base(reduction: reduction,
name: name)
{
}

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1);
}
}
}

Loading…
Cancel
Save