|
|
|
@@ -8,6 +8,7 @@ using System.Text; |
|
|
|
using NumSharp; |
|
|
|
using Tensorflow; |
|
|
|
using Tensorflow.Keras.Engine; |
|
|
|
using Tensorflow.Sessions; |
|
|
|
using TensorFlowNET.Examples.Text.cnn_models; |
|
|
|
using TensorFlowNET.Examples.TextClassification; |
|
|
|
using TensorFlowNET.Examples.Utility; |
|
|
|
@@ -91,7 +92,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification |
|
|
|
foreach (var (x_batch, y_batch, total) in train_batches) |
|
|
|
{ |
|
|
|
i++; |
|
|
|
var train_feed_dict = new Hashtable |
|
|
|
var train_feed_dict = new FeedDict |
|
|
|
{ |
|
|
|
[model_x] = x_batch, |
|
|
|
[model_y] = y_batch, |
|
|
|
@@ -113,25 +114,26 @@ namespace TensorFlowNET.Examples.CnnTextClassification |
|
|
|
|
|
|
|
if (step % 100 == 0) |
|
|
|
{ |
|
|
|
continue;
|
|
|
|
// # Test accuracy with validation data for each epoch.
|
|
|
|
var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1); |
|
|
|
var (sum_accuracy, cnt) = (0, 0); |
|
|
|
var (sum_accuracy, cnt) = (0.0f, 0); |
|
|
|
foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches) |
|
|
|
{ |
|
|
|
// valid_feed_dict = { |
|
|
|
// model.x: valid_x_batch, |
|
|
|
// model.y: valid_y_batch, |
|
|
|
// model.is_training: False |
|
|
|
// } |
|
|
|
|
|
|
|
// accuracy = sess.run(model.accuracy, feed_dict = valid_feed_dict) |
|
|
|
// sum_accuracy += accuracy |
|
|
|
// cnt += 1 |
|
|
|
var valid_feed_dict = new FeedDict |
|
|
|
{ |
|
|
|
[model_x] = valid_x_batch, |
|
|
|
[model_y] = valid_y_batch, |
|
|
|
[is_training] = false |
|
|
|
}; |
|
|
|
var result1 = sess.run(accuracy, valid_feed_dict); |
|
|
|
float accuracy_value = result1; |
|
|
|
sum_accuracy += accuracy_value; |
|
|
|
cnt += 1; |
|
|
|
} |
|
|
|
// valid_accuracy = sum_accuracy / cnt |
|
|
|
|
|
|
|
// print("\nValidation Accuracy = {1}\n".format(step // num_batches_per_epoch, sum_accuracy / cnt)) |
|
|
|
var valid_accuracy = sum_accuracy / cnt; |
|
|
|
|
|
|
|
print($"\nValidation Accuracy = {valid_accuracy}\n"); |
|
|
|
|
|
|
|
// # Save model |
|
|
|
// if valid_accuracy > max_accuracy: |
|
|
|
|