using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Operations;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

namespace Tensorflow.Keras.Losses
{
    /// <summary>
    /// Log-cosh loss: the mean over <c>axis</c> of <c>log(cosh(y_pred - y_true))</c>.
    /// </summary>
    public class LogCosh : LossFunctionWrapper, ILossFunc
    {
        /// <summary>Creates a log-cosh loss.</summary>
        /// <param name="reduction">Reduction mode, forwarded unchanged to the <see cref="LossFunctionWrapper"/> base.</param>
        /// <param name="name">Loss name; <c>"log_cosh"</c> is used when this is null.</param>
        public LogCosh(
            string reduction = null,
            string name = null) :
            base(reduction: reduction, name: name ?? "log_cosh")
        {
        }

        /// <summary>
        /// Computes <c>mean(log(cosh(y_pred - y_true)), axis)</c>.
        /// </summary>
        /// <param name="y_true">Target values; cast to the dtype of <paramref name="y_pred"/>.</param>
        /// <param name="y_pred">Predicted values.</param>
        /// <param name="from_logits">Not used by this loss.</param>
        /// <param name="axis">Axis to average over (defaults to the last axis, -1).</param>
        /// <returns>The per-sample loss tensor, reduced along <paramref name="axis"/>.</returns>
        public override Tensor Apply(Tensor y_true = null, Tensor y_pred = null, bool from_logits = false, int axis = -1)
        {
            Tensor y_pred_dispatch = ops.convert_to_tensor(y_pred);
            Tensor y_true_cast = gen_math_ops.cast(y_true, y_pred_dispatch.dtype);
            Tensor x = y_pred_dispatch - y_true_cast;

            // log(2) is a plain constant. Building it from a tf.Variable (as before) created a new
            // trainable variable on every call.
            Tensor log2 = math_ops.cast(ops.convert_to_tensor(System.Math.Log(2.0)), x.dtype);

            // Numerically stable identity: log(cosh(x)) == x + softplus(-2x) - log(2),
            // since cosh(x) = e^x * (1 + e^(-2x)) / 2.
            return gen_math_ops.mean(x + gen_nn_ops.softplus(-2.0 * x) - log2, ops.convert_to_tensor(axis));
        }
    }
}