Browse Source

Fix the keras.sparse_categorical_crossentropy. (#985)

pull/987/head
AsakusaRinne Yaohui Liu 2 years ago
parent
commit
b72c135a93
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
1 changed files with 7 additions and 0 deletions
  1. +7
    -0
      src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs

+ 7
- 0
src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs View File

@@ -14,6 +14,13 @@ namespace Tensorflow.Keras.Losses
{ {
target = tf.cast(target, dtype: TF_DataType.TF_INT64); target = tf.cast(target, dtype: TF_DataType.TF_INT64);


if (!from_logits)
{
var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype);
output = tf.clip_by_value(output, epsilon, 1 - epsilon);
output = tf.log(output);
}

// Try to adjust the shape so that rank of labels = rank of logits - 1. // Try to adjust the shape so that rank of labels = rank of logits - 1.
var output_shape = array_ops.shape_v2(output); var output_shape = array_ops.shape_v2(output);
var output_rank = output.shape.ndim; var output_rank = output.shape.ndim;


Loading…
Cancel
Save