| @@ -3,6 +3,6 @@ | |||||
| public interface ILossFunc | public interface ILossFunc | ||||
| { | { | ||||
| string Reduction { get; } | string Reduction { get; } | ||||
| Tensor Call(Tensor y_true, Tensor y_pred); | |||||
| Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null); | |||||
| } | } | ||||
| } | } | ||||
| @@ -31,10 +31,11 @@ namespace Tensorflow.Keras.Losses | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| public Tensor Call(Tensor y_true, Tensor y_pred) | |||||
| public Tensor Call(Tensor y_true, Tensor y_pred, Tensor sample_weight = null) | |||||
| { | { | ||||
| var losses = Apply(y_true, y_pred, from_logits: from_logits); | var losses = Apply(y_true, y_pred, from_logits: from_logits); | ||||
| return losses_utils.compute_weighted_loss(losses, reduction: ReductionV2.SUM_OVER_BATCH_SIZE); | |||||
| return losses_utils.compute_weighted_loss(losses, reduction: this.reduction == null?ReductionV2.SUM_OVER_BATCH_SIZE : this.reduction, sample_weight: sample_weight); | |||||
| } | } | ||||
| void _set_name_scope() | void _set_name_scope() | ||||
| @@ -7,5 +7,8 @@ | |||||
| public ILossFunc CategoricalCrossentropy(bool from_logits = false) | public ILossFunc CategoricalCrossentropy(bool from_logits = false) | ||||
| => new CategoricalCrossentropy(from_logits: from_logits); | => new CategoricalCrossentropy(from_logits: from_logits); | ||||
| public ILossFunc MeanSquaredError(string reduction = null) | |||||
| => new MeanSquaredError(reduction: reduction); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,26 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| using static Tensorflow.KerasApi; | |||||
| namespace Tensorflow.Keras.Losses | |||||
| { | |||||
| public class MeanSquaredError : LossFunctionWrapper, ILossFunc | |||||
| { | |||||
| public MeanSquaredError( | |||||
| string reduction = ReductionV2.AUTO, | |||||
| string name = "mean_squared_error") : | |||||
| base(reduction: reduction, | |||||
| name: name) | |||||
| { | |||||
| } | |||||
| public override Tensor Apply(Tensor y_true = null, Tensor y_pred =null, bool from_logits = false, int axis = -1) | |||||
| { | |||||
| Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred); | |||||
| Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype); | |||||
| return gen_math_ops.mean(gen_math_ops.squared_difference(y_pred_dispatch, y_true_cast), axis: -1); | |||||
| } | |||||
| } | |||||
| } | |||||