Browse Source

更新Mse

pull/684/head
dataangel 5 years ago
parent
commit
3d147bd4fd
4 changed files with 33 additions and 3 deletions
  1. +1
    -1
      src/TensorFlowNET.Keras/Losses/ILossFunc.cs
  2. +3
    -2
      src/TensorFlowNET.Keras/Losses/Loss.cs
  3. +3
    -0
      src/TensorFlowNET.Keras/Losses/LossesApi.cs
  4. +26
    -0
      src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs

+ 1
- 1
src/TensorFlowNET.Keras/Losses/ILossFunc.cs View File

@@ -3,6 +3,6 @@
public interface ILossFunc public interface ILossFunc
{ {
string Reduction { get; } string Reduction { get; }
Tensor Call(Tensor y_true, Tensor y_pred);
Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null);
} }
} }

+ 3
- 2
src/TensorFlowNET.Keras/Losses/Loss.cs View File

@@ -31,10 +31,11 @@ namespace Tensorflow.Keras.Losses
throw new NotImplementedException(""); throw new NotImplementedException("");
} }


public Tensor Call(Tensor y_true, Tensor y_pred)
public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null)
{ {
var losses = Apply(y_true, y_pred, from_logits: from_logits); var losses = Apply(y_true, y_pred, from_logits: from_logits);
return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE);

return losses_utils.compute_weighted_loss(losses, reduction: this.reduction == null?ReductionV2.SUM_OVER_BATCH_SIZE : this.reduction, sample_weight: sample_weight);
} }


void _set_name_scope() void _set_name_scope()


+ 3
- 0
src/TensorFlowNET.Keras/Losses/LossesApi.cs View File

@@ -7,5 +7,8 @@


public ILossFunc CategoricalCrossentropy(bool from_logits = false) public ILossFunc CategoricalCrossentropy(bool from_logits = false)
=> new CategoricalCrossentropy(from_logits: from_logits); => new CategoricalCrossentropy(from_logits: from_logits);
public ILossFunc MeanSquaredError(string reduction = null)
=> new MeanSquaredError(reduction: reduction);
} }
} }

+ 26
- 0
src/TensorFlowNET.Keras/Losses/MeanSquaredError.cs View File

@@ -0,0 +1,26 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Losses
{
public class MeanSquaredError : LossFunctionWrapper, ILossFunc
{
public MeanSquaredError(
string reduction = ReductionV2.AUTO,
string name = "mean_squared_error") :
base(reduction: reduction,
name: name)
{
}

public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1)
{
Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1);
}
}
}

Loading…
Cancel
Save