diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index 587aca5b..22339226 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -19,7 +19,7 @@ namespace Tensorflow public BaseSession(string target = "", Graph graph = null) { - if(graph is null) + if (graph is null) { _graph = ops.get_default_graph(); } @@ -41,9 +41,9 @@ namespace Tensorflow return _run(fetches, feed_dict); } - public virtual NDArray run(ITensorOrOperation[] fetches, Hashtable feed_dict = null) + public virtual NDArray run(object fetches, Hashtable feed_dict = null) { - var feed_items = feed_dict == null ? new FeedItem[0] : + var feed_items = feed_dict == null ? new FeedItem[0] : feed_dict.Keys.OfType().Select(key => new FeedItem(key, feed_dict[key])).ToArray(); return _run(fetches, feed_items); } @@ -98,7 +98,7 @@ namespace Tensorflow feed_dict_tensor[subfeed_t] = (NDArray)val; break; case bool val: - feed_dict_tensor[subfeed_t] = (NDArray) val; + feed_dict_tensor[subfeed_t] = (NDArray)val; break; case bool[] val: feed_dict_tensor[subfeed_t] = (NDArray)val; @@ -106,8 +106,8 @@ namespace Tensorflow default: Console.WriteLine($"can't handle data type of subfeed_val"); throw new NotImplementedException("_run subfeed"); - } - + } + feed_map[subfeed_t.name] = (subfeed_t, subfeed_val); } } @@ -146,9 +146,9 @@ namespace Tensorflow /// private NDArray[] _do_run(List target_list, List fetch_list, Dictionary feed_dict) { - var feeds = feed_dict.Select(x => + var feeds = feed_dict.Select(x => { - if(x.Key is Tensor tensor) + if (x.Key is Tensor tensor) { switch (x.Value) { diff --git a/src/TensorFlowNET.Core/Sessions/FeedDict.cs b/src/TensorFlowNET.Core/Sessions/FeedDict.cs new file mode 100644 index 00000000..95e51b06 --- /dev/null +++ b/src/TensorFlowNET.Core/Sessions/FeedDict.cs @@ -0,0 +1,11 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Sessions +{ + public class FeedDict : Hashtable + { + } +} diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs index a8eb01c9..73c74e3f 100644 --- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs +++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs @@ -8,6 +8,7 @@ using System.Text; using NumSharp; using Tensorflow; using Tensorflow.Keras.Engine; +using Tensorflow.Sessions; using TensorFlowNET.Examples.Text.cnn_models; using TensorFlowNET.Examples.TextClassification; using TensorFlowNET.Examples.Utility; @@ -91,7 +92,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification foreach (var (x_batch, y_batch, total) in train_batches) { i++; - var train_feed_dict = new Hashtable + var train_feed_dict = new FeedDict { [model_x] = x_batch, [model_y] = y_batch, @@ -113,25 +114,26 @@ namespace TensorFlowNET.Examples.CnnTextClassification if (step % 100 == 0) { - continue; // # Test accuracy with validation data for each epoch. var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1); - var (sum_accuracy, cnt) = (0, 0); + var (sum_accuracy, cnt) = (0.0f, 0); foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches) { - // valid_feed_dict = { - // model.x: valid_x_batch, - // model.y: valid_y_batch, - // model.is_training: False - // } - - // accuracy = sess.run(model.accuracy, feed_dict = valid_feed_dict) - // sum_accuracy += accuracy - // cnt += 1 + var valid_feed_dict = new FeedDict + { + [model_x] = valid_x_batch, + [model_y] = valid_y_batch, + [is_training] = false + }; + var result1 = sess.run(accuracy, valid_feed_dict); + float accuracy_value = result1; + sum_accuracy += accuracy_value; + cnt += 1; } - // valid_accuracy = sum_accuracy / cnt - // print("\nValidation Accuracy = {1}\n".format(step // num_batches_per_epoch, sum_accuracy / cnt)) + var valid_accuracy = sum_accuracy / cnt; + + print($"\nValidation Accuracy = {valid_accuracy}\n"); // # Save model // if valid_accuracy > max_accuracy: