diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index bf0d4680..bfa71f31 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -90,6 +90,17 @@ namespace Tensorflow
public static Tensor softmax(Tensor logits, int axis = -1, string name = null)
=> gen_nn_ops.softmax(logits, name);
+ ///
+ /// Computes sparse softmax cross entropy between `logits` and `labels`.
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor sparse_softmax_cross_entropy_with_logits(Tensor labels = null,
+ Tensor logits = null, string name = null)
+ => nn_ops.sparse_softmax_cross_entropy_with_logits(labels: labels, logits: logits, name: name);
+
public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
=> nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
}
diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
index 3060ded8..a96e0846 100644
--- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
@@ -203,6 +203,27 @@ namespace TensorFlowNET.Examples
var h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
});
+ Tensor logits = null;
+ Tensor predictions = null;
+ with(tf.name_scope("output"), delegate
+ {
+ logits = tf.layers.dense(h_pool_flat, keep_prob);
+ predictions = tf.argmax(logits, -1, output_type: tf.int32);
+ });
+
+ with(tf.name_scope("loss"), delegate
+ {
+ var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
+ var loss = tf.reduce_mean(sscel);
+ var optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step: global_step);
+ });
+
+ with(tf.name_scope("accuracy"), delegate
+ {
+ var correct_predictions = tf.equal(predictions, y);
+ var accuracy = tf.reduce_mean(tf.cast(correct_predictions, TF_DataType.TF_FLOAT), name: "accuracy");
+ });
+
return graph;
}