From a4f03c22ec6391e7c22713fdaff4bd2b99da3afc Mon Sep 17 00:00:00 2001
From: Oceania2018
Date: Wed, 5 Jun 2019 07:28:27 -0500
Subject: [PATCH] added sparse_softmax_cross_entropy_with_logits

---
 src/TensorFlowNET.Core/APIs/tf.nn.cs                  | 11 ++++++++++
 .../TextProcess/CnnTextClassification.cs              | 21 +++++++++++++++++++
 2 files changed, 32 insertions(+)

diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index bf0d4680..bfa71f31 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -90,6 +90,17 @@ namespace Tensorflow
         public static Tensor softmax(Tensor logits, int axis = -1, string name = null)
             => gen_nn_ops.softmax(logits, name);
 
+        /// <summary>
+        /// Computes sparse softmax cross entropy between `logits` and `labels`.
+        /// </summary>
+        /// <param name="labels"></param>
+        /// <param name="logits"></param>
+        /// <param name="name"></param>
+        /// <returns></returns>
+        public static Tensor sparse_softmax_cross_entropy_with_logits(Tensor labels = null,
+            Tensor logits = null, string name = null)
+            => nn_ops.sparse_softmax_cross_entropy_with_logits(labels: labels, logits: logits, name: name);
+
         public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
             => nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
     }
diff --git a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
index 3060ded8..a96e0846 100644
--- a/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/CnnTextClassification.cs
@@ -203,6 +203,27 @@ namespace TensorFlowNET.Examples
                 var h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
             });
 
+            Tensor logits = null;
+            Tensor predictions = null;
+            with(tf.name_scope("output"), delegate
+            {
+                logits = tf.layers.dense(h_pool_flat, keep_prob);
+                predictions = tf.argmax(logits, -1, output_type: tf.int32);
+            });
+
+            with(tf.name_scope("loss"), delegate
+            {
+                var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
+                var loss = tf.reduce_mean(sscel);
+                var optimizer = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step: global_step);
+            });
+
+            with(tf.name_scope("accuracy"), delegate
+            {
+                var correct_predictions = tf.equal(predictions, y);
+                var accuracy = tf.reduce_mean(tf.cast(correct_predictions, TF_DataType.TF_FLOAT), name: "accuracy");
+            });
+
             return graph;
         }
 