|
|
|
@@ -14,6 +14,13 @@ namespace Tensorflow.Keras.Losses |
|
|
|
{ |
|
|
|
target = tf.cast(target, dtype: TF_DataType.TF_INT64); |
|
|
|
|
|
|
|
if (!from_logits) |
|
|
|
{ |
|
|
|
var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype); |
|
|
|
output = tf.clip_by_value(output, epsilon, 1 - epsilon); |
|
|
|
output = tf.log(output); |
|
|
|
} |
|
|
|
|
|
|
|
// Try to adjust the shape so that rank of labels = rank of logits - 1. |
|
|
|
var output_shape = array_ops.shape_v2(output); |
|
|
|
var output_rank = output.shape.ndim; |
|
|
|
|