diff --git a/graph/vd_cnn.meta b/graph/vd_cnn.meta
index 1c676f52..b857fc6c 100644
Binary files a/graph/vd_cnn.meta and b/graph/vd_cnn.meta differ
diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
index d9877d8c..891ad81f 100644
--- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
@@ -52,38 +52,37 @@ namespace TensorFlowNET.Examples.CnnTextClassification
         protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
         {
+            var stopwatch = Stopwatch.StartNew();
             Console.WriteLine("Building dataset...");
             var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit=null);
-            Console.WriteLine("\tDONE");
+            Console.WriteLine("\tDONE ");
             var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
 
             Console.WriteLine("Import graph...");
             var meta_file = model_name + ".meta";
             tf.train.import_meta_graph(Path.Join("graph", meta_file));
-            Console.WriteLine("\tDONE");
-            // definitely necessary, otherwize will get the exception of "use uninitialized variable"
+            Console.WriteLine("\tDONE " + stopwatch.Elapsed);
+
             sess.run(tf.global_variables_initializer());
 
             var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
-            var num_batches_per_epoch = (len(train_x) - 1); // BATCH_SIZE + 1
+            var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
             double max_accuracy = 0;
 
             Tensor is_training = graph.get_operation_by_name("is_training");
             Tensor model_x = graph.get_operation_by_name("x");
             Tensor model_y = graph.get_operation_by_name("y");
-            Tensor loss = graph.get_operation_by_name("loss/loss");
+            Tensor loss = graph.get_operation_by_name("loss/value");
             //var optimizer_nodes = graph._nodes_by_name.Keys.Where(key => key.Contains("optimizer")).ToArray();
             Tensor optimizer = graph.get_operation_by_name("loss/optimizer");
             Tensor global_step = graph.get_operation_by_name("global_step");
-            Tensor accuracy = graph.get_operation_by_name("accuracy/accuracy");
-            var stopwatch = Stopwatch.StartNew();
+            Tensor accuracy = graph.get_operation_by_name("accuracy/value");
+            stopwatch = Stopwatch.StartNew();
             int i = 0;
             foreach (var (x_batch, y_batch, total) in train_batches)
             {
                 i++;
-                var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
-                Console.WriteLine($"Training on batch {i}/{total}. Estimated training time: {estimate}");
                 var train_feed_dict = new Hashtable
                 {
                     [model_x] = x_batch,
@@ -94,9 +93,14 @@ namespace TensorFlowNET.Examples.CnnTextClassification
                 //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
                 var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
                 //loss_value = result[2];
-                var step = result[1];
+                var step = result[1];
                 if (step % 10 == 0)
+                {
+                    var estimate = TimeSpan.FromSeconds((stopwatch.Elapsed.TotalSeconds / i) * total);
+                    Console.WriteLine($"Training on batch {i}/{total}. Estimated training time: {estimate}");
                     Console.WriteLine($"Step {step} loss: {result[2]}");
+                }
+
                 if (step % 100 == 0)
                 {
                     continue;
@@ -198,6 +202,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification
             {
                 // download graph meta data
                 var meta_file = model_name + ".meta";
+                if (File.GetLastWriteTime(meta_file) < new DateTime(2019,05,11)) // delete old cached file which contains errors
+                    File.Delete(meta_file);
                 url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
                 Web.Download(url, "graph", meta_file);
             }
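Note on the batch math touched by this patch: (len(train_x) - 1) / BATCH_SIZE + 1 is integer ceiling division, so a trailing partial batch still counts as one batch, and the ETA printed every 10 steps scales the average per-batch time to the total batch count. Below is a minimal standalone C# sketch of that arithmetic, not part of the commit; the sample count, batch size, and elapsed time are hypothetical stand-ins for len(train_x), BATCH_SIZE, and the stopwatch reading.

using System;

class BatchMathSketch
{
    static void Main()
    {
        int sampleCount = 168_000;   // hypothetical, stands in for len(train_x)
        int batchSize = 128;         // hypothetical, stands in for BATCH_SIZE

        // (n - 1) / b + 1 equals ceil(n / b) for positive integers,
        // so the last partial batch is still counted.
        int numBatchesPerEpoch = (sampleCount - 1) / batchSize + 1;
        Console.WriteLine($"Batches per epoch: {numBatchesPerEpoch}");

        // ETA as printed every 10 steps: average seconds per finished batch
        // multiplied by the total number of batches in the run.
        int finishedBatches = 50;            // hypothetical, stands in for i
        double elapsedSeconds = 300.0;       // hypothetical, stands in for stopwatch.Elapsed.TotalSeconds
        int totalBatches = numBatchesPerEpoch;   // the real loop uses the total yielded by batch_iter
        var estimate = TimeSpan.FromSeconds(elapsedSeconds / finishedBatches * totalBatches);
        Console.WriteLine($"Estimated training time: {estimate}");
    }
}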