@@ -19,7 +19,7 @@ namespace Tensorflow
         public BaseSession(string target = "", Graph graph = null)
         {
-            if(graph is null)
+            if (graph is null)
             {
                 _graph = ops.get_default_graph();
             }
@@ -41,9 +41,9 @@ namespace Tensorflow
             return _run(fetches, feed_dict);
         }

-        public virtual NDArray run(ITensorOrOperation[] fetches, Hashtable feed_dict = null)
+        public virtual NDArray run(object fetches, Hashtable feed_dict = null)
         {
-            var feed_items = feed_dict == null ? new FeedItem[0] :
+            var feed_items = feed_dict == null ? new FeedItem[0] :
                 feed_dict.Keys.OfType<object>().Select(key => new FeedItem(key, feed_dict[key])).ToArray();
             return _run(fetches, feed_items);
         }
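For reference, the overload above now takes `object` fetches rather than `ITensorOrOperation[]`, so a single tensor or operation can be passed without an array wrapper. A minimal sketch of the call-site shape this enables; `sess`, `loss`, `train_op`, and `x` are illustrative names, not part of this diff:

```csharp
// Hypothetical usage sketch, assuming an existing Session `sess`,
// a placeholder `x`, a Tensor `loss`, and an Operation `train_op`.
var feed = new Hashtable { [x] = np.array(new float[] { 1f, 2f, 3f }) };

// A single fetch no longer needs an array wrapper.
NDArray loss_value = sess.run(loss, feed);

// Arrays of fetches still flow through the same `object` parameter.
var results = sess.run(new ITensorOrOperation[] { loss, train_op }, feed);
```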
@@ -98,7 +98,7 @@ namespace Tensorflow
                         feed_dict_tensor[subfeed_t] = (NDArray)val;
                         break;
                     case bool val:
-                        feed_dict_tensor[subfeed_t] = (NDArray) val;
+                        feed_dict_tensor[subfeed_t] = (NDArray)val;
                         break;
                     case bool[] val:
                         feed_dict_tensor[subfeed_t] = (NDArray)val;
@@ -106,8 +106,8 @@ namespace Tensorflow
                     default:
                         Console.WriteLine($"can't handle data type of subfeed_val");
                         throw new NotImplementedException("_run subfeed");
-                    }
+                }
                 feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
             }
         }
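The two hunks above touch the switch in `_run` that normalizes each fed value to an `NDArray` before storing it in `feed_dict_tensor`; anything it cannot handle falls through to the `default` branch and throws. A minimal sketch of what that means at the call site, assuming placeholders `flag` and `mask` exist (names are illustrative):

```csharp
// Values of the types the switch handles (e.g. bool and bool[]) are
// cast to NDArray internally; unhandled types currently hit the
// default branch and raise NotImplementedException.
var feed = new Hashtable
{
    [flag] = true,                        // matched by `case bool val`
    [mask] = new[] { true, false, true }  // matched by `case bool[] val`
};
sess.run(some_fetch, feed);
```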
@@ -146,9 +146,9 @@ namespace Tensorflow
         /// </returns>
         private NDArray[] _do_run(List<Operation> target_list, List<Tensor> fetch_list, Dictionary<object, object> feed_dict)
         {
-            var feeds = feed_dict.Select(x =>
+            var feeds = feed_dict.Select(x =>
             {
-                if(x.Key is Tensor tensor)
+                if (x.Key is Tensor tensor)
                 {
                     switch (x.Value)
                     {
@@ -0,0 +1,11 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Sessions
+{
+    public class FeedDict : Hashtable
+    {
+    }
+}
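`FeedDict` is introduced as an empty subclass of `Hashtable`: for now it only gives feed dictionaries their own type, presumably so session APIs can target it later. A sketch of the intended call-site shape, mirroring the example change below; the placeholder and session names are assumed to exist:

```csharp
using Tensorflow.Sessions;

// FeedDict adds no behavior over Hashtable yet: keys are placeholder
// tensors, values are the data fed into them for one run() call.
var feed = new FeedDict
{
    [model_x] = x_batch,
    [model_y] = y_batch,
    [is_training] = true
};
var step_loss = sess.run(loss, feed);
```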
@@ -8,6 +8,7 @@ using System.Text;
 using NumSharp;
 using Tensorflow;
 using Tensorflow.Keras.Engine;
+using Tensorflow.Sessions;
 using TensorFlowNET.Examples.Text.cnn_models;
 using TensorFlowNET.Examples.TextClassification;
 using TensorFlowNET.Examples.Utility;
@@ -91,7 +92,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
                 foreach (var (x_batch, y_batch, total) in train_batches)
                 {
                     i++;
-                    var train_feed_dict = new Hashtable
+                    var train_feed_dict = new FeedDict
                     {
                         [model_x] = x_batch,
                         [model_y] = y_batch,
@@ -113,25 +114,26 @@ namespace TensorFlowNET.Examples.CnnTextClassification
                     if (step % 100 == 0)
                     {
-                        continue;
                         // # Test accuracy with validation data for each epoch.
                         var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1);
-                        var (sum_accuracy, cnt) = (0, 0);
+                        var (sum_accuracy, cnt) = (0.0f, 0);
                         foreach (var (valid_x_batch, valid_y_batch, total_validation_batches) in valid_batches)
                         {
-                            // valid_feed_dict = {
-                            //     model.x: valid_x_batch,
-                            //     model.y: valid_y_batch,
-                            //     model.is_training: False
-                            // }
-                            // accuracy = sess.run(model.accuracy, feed_dict = valid_feed_dict)
-                            // sum_accuracy += accuracy
-                            // cnt += 1
+                            var valid_feed_dict = new FeedDict
+                            {
+                                [model_x] = valid_x_batch,
+                                [model_y] = valid_y_batch,
+                                [is_training] = false
+                            };
+                            var result1 = sess.run(accuracy, valid_feed_dict);
+                            float accuracy_value = result1;
+                            sum_accuracy += accuracy_value;
+                            cnt += 1;
                         }
-                        // valid_accuracy = sum_accuracy / cnt
-                        // print("\nValidation Accuracy = {1}\n".format(step // num_batches_per_epoch, sum_accuracy / cnt))
+                        var valid_accuracy = sum_accuracy / cnt;
+                        print($"\nValidation Accuracy = {valid_accuracy}\n");
                         // # Save model
                         // if valid_accuracy > max_accuracy:
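A note on the accumulator change in this last hunk: with `var (sum_accuracy, cnt) = (0, 0)` both variables are `int`, so `sum_accuracy += accuracy_value` would not compile against a `float` result, and `sum_accuracy / cnt` would truncate to an integer even if it did. A small standalone illustration of the truncation pitfall, independent of the example code:

```csharp
int sumInt = 7, cnt = 9;
Console.WriteLine(sumInt / cnt);    // 0 — integer division truncates

float sumFloat = 7f;                // the diff's fix: a float accumulator
Console.WriteLine(sumFloat / cnt);  // 0.7777778 — fraction preserved
```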