| @@ -4,17 +4,21 @@ namespace Tensorflow.Keras.Losses | |||||
| { | { | ||||
| public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc | public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc | ||||
| { | { | ||||
| private bool _from_logits = false; | |||||
| public SparseCategoricalCrossentropy( | public SparseCategoricalCrossentropy( | ||||
| bool from_logits = false, | bool from_logits = false, | ||||
| string reduction = null, | string reduction = null, | ||||
| string name = null) : | string name = null) : | ||||
| base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name){ } | |||||
| base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name) | |||||
| { | |||||
| _from_logits = from_logits; | |||||
| } | |||||
| public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) | public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) | ||||
| { | { | ||||
| target = tf.cast(target, dtype: TF_DataType.TF_INT64); | target = tf.cast(target, dtype: TF_DataType.TF_INT64); | ||||
| if (!from_logits) | |||||
| if (!_from_logits) | |||||
| { | { | ||||
| var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype); | var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype); | ||||
| output = tf.clip_by_value(output, epsilon, 1 - epsilon); | output = tf.clip_by_value(output, epsilon, 1 - epsilon); | ||||