|
|
|
@@ -4,17 +4,21 @@ namespace Tensorflow.Keras.Losses |
|
|
|
{ |
|
|
|
public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc |
|
|
|
{ |
|
|
|
private bool _from_logits = false; |
|
|
|
public SparseCategoricalCrossentropy( |
|
|
|
bool from_logits = false, |
|
|
|
string reduction = null, |
|
|
|
string name = null) : |
|
|
|
base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name){ } |
|
|
|
base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name) |
|
|
|
{ |
|
|
|
_from_logits = from_logits; |
|
|
|
} |
|
|
|
|
|
|
|
public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) |
|
|
|
{ |
|
|
|
target = tf.cast(target, dtype: TF_DataType.TF_INT64); |
|
|
|
|
|
|
|
if (!from_logits) |
|
|
|
if (!_from_logits) |
|
|
|
{ |
|
|
|
var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype); |
|
|
|
output = tf.clip_by_value(output, epsilon, 1 - epsilon); |
|
|
|
|