Browse Source

Fix the bug of non-convergence when use SparseCategoricalCrossentropy

tags/v0.100.5-BERT-load
Wanglongzhi2001 2 years ago
parent
commit
5506f00906
1 changed files with 6 additions and 2 deletions
  1. +6
    -2
      src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs

+ 6
- 2
src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs View File

@@ -4,17 +4,21 @@ namespace Tensorflow.Keras.Losses
{
public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc
{
private bool _from_logits = false;
public SparseCategoricalCrossentropy(
bool from_logits = false,
string reduction = null,
string name = null) :
base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name){ }
base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name)
{
_from_logits = from_logits;
}

public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1)
{
target = tf.cast(target, dtype: TF_DataType.TF_INT64);

if (!from_logits)
if (!_from_logits)
{
var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype);
output = tf.clip_by_value(output, epsilon, 1 - epsilon);


Loading…
Cancel
Save