From 47a7ac8017b1dfe5c832d7957f69b030e5daca40 Mon Sep 17 00:00:00 2001 From: Meinrad Recheis Date: Sat, 11 May 2019 15:19:43 +0200 Subject: [PATCH] TextClassification: delete old cached meta graph which has a bug --- .../TextProcess/TextClassificationTrain.cs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs index 891ad81f..9d24e3c6 100644 --- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs +++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs @@ -32,6 +32,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn private const int CHAR_MAX_LEN = 1014; + private const int WORD_MAX_LEN = 1014; private const int NUM_CLASS = 2; private const int BATCH_SIZE = 64; private const int NUM_EPOCHS = 10; @@ -58,6 +59,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification Console.WriteLine("\tDONE "); var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); + Console.WriteLine("Training set size: " + train_x.shape[0]); + Console.WriteLine("Test set size: " + valid_x.shape[0]); Console.WriteLine("Import graph..."); var meta_file = model_name + ".meta"; @@ -74,7 +77,6 @@ namespace TensorFlowNET.Examples.CnnTextClassification Tensor model_x = graph.get_operation_by_name("x"); Tensor model_y = graph.get_operation_by_name("y"); Tensor loss = graph.get_operation_by_name("loss/value"); - //var optimizer_nodes = graph._nodes_by_name.Keys.Where(key => key.Contains("optimizer")).ToArray(); Tensor optimizer = graph.get_operation_by_name("loss/optimizer"); Tensor global_step = graph.get_operation_by_name("global_step"); Tensor accuracy = graph.get_operation_by_name("accuracy/value"); @@ -92,7 +94,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification // original python: //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict) var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict); - //loss_value = result[2]; + loss_value = result[2]; var step = result[1]; if (step % 10 == 0) { @@ -177,7 +179,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs) { - var num_batches_per_epoch = (len(inputs) - 1) / batch_size; + var num_batches_per_epoch = (len(inputs) - 1) / batch_size +1; var total_batches = num_batches_per_epoch * num_epochs; foreach (var epoch in range(num_epochs)) { @@ -202,8 +204,13 @@ namespace TensorFlowNET.Examples.CnnTextClassification { // download graph meta data var meta_file = model_name + ".meta"; - if (File.GetLastWriteTime(meta_file) < new DateTime(2019,05,11)) // delete old cached file which contains errors - File.Delete(meta_file); + var meta_path = Path.Combine("graph", meta_file); + if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11)) + { + // delete old cached file which contains errors + Console.WriteLine("Discarding cached file: " + meta_path); + File.Delete(meta_path); + } url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; Web.Download(url, "graph", meta_file); }