Browse Source

TextClassification: delete old cached meta graph which has a bug

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
47a7ac8017
1 changed files with 12 additions and 5 deletions
  1. +12
    -5
      test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs

+ 12
- 5
test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs View File

@@ -32,6 +32,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn public string model_name = "vd_cnn"; // word_cnn | char_cnn | vd_cnn | word_rnn | att_rnn | rcnn


private const int CHAR_MAX_LEN = 1014; private const int CHAR_MAX_LEN = 1014;
private const int WORD_MAX_LEN = 1014;
private const int NUM_CLASS = 2; private const int NUM_CLASS = 2;
private const int BATCH_SIZE = 64; private const int BATCH_SIZE = 64;
private const int NUM_EPOCHS = 10; private const int NUM_EPOCHS = 10;
@@ -58,6 +59,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification
Console.WriteLine("\tDONE "); Console.WriteLine("\tDONE ");


var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f); var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
Console.WriteLine("Training set size: " + train_x.shape[0]);
Console.WriteLine("Test set size: " + valid_x.shape[0]);


Console.WriteLine("Import graph..."); Console.WriteLine("Import graph...");
var meta_file = model_name + ".meta"; var meta_file = model_name + ".meta";
@@ -74,7 +77,6 @@ namespace TensorFlowNET.Examples.CnnTextClassification
Tensor model_x = graph.get_operation_by_name("x"); Tensor model_x = graph.get_operation_by_name("x");
Tensor model_y = graph.get_operation_by_name("y"); Tensor model_y = graph.get_operation_by_name("y");
Tensor loss = graph.get_operation_by_name("loss/value"); Tensor loss = graph.get_operation_by_name("loss/value");
//var optimizer_nodes = graph._nodes_by_name.Keys.Where(key => key.Contains("optimizer")).ToArray();
Tensor optimizer = graph.get_operation_by_name("loss/optimizer"); Tensor optimizer = graph.get_operation_by_name("loss/optimizer");
Tensor global_step = graph.get_operation_by_name("global_step"); Tensor global_step = graph.get_operation_by_name("global_step");
Tensor accuracy = graph.get_operation_by_name("accuracy/value"); Tensor accuracy = graph.get_operation_by_name("accuracy/value");
@@ -92,7 +94,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
// original python: // original python:
//_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict) //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict); var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
//loss_value = result[2];
loss_value = result[2];
var step = result[1]; var step = result[1];
if (step % 10 == 0) if (step % 10 == 0)
{ {
@@ -177,7 +179,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification


private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs) private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
{ {
var num_batches_per_epoch = (len(inputs) - 1) / batch_size;
var num_batches_per_epoch = (len(inputs) - 1) / batch_size +1;
var total_batches = num_batches_per_epoch * num_epochs; var total_batches = num_batches_per_epoch * num_epochs;
foreach (var epoch in range(num_epochs)) foreach (var epoch in range(num_epochs))
{ {
@@ -202,8 +204,13 @@ namespace TensorFlowNET.Examples.CnnTextClassification
{ {
// download graph meta data // download graph meta data
var meta_file = model_name + ".meta"; var meta_file = model_name + ".meta";
if (File.GetLastWriteTime(meta_file) < new DateTime(2019,05,11)) // delete old cached file which contains errors
File.Delete(meta_file);
var meta_path = Path.Combine("graph", meta_file);
if (File.GetLastWriteTime(meta_path) < new DateTime(2019, 05, 11))
{
// delete old cached file which contains errors
Console.WriteLine("Discarding cached file: " + meta_path);
File.Delete(meta_path);
}
url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file; url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
Web.Download(url, "graph", meta_file); Web.Download(url, "graph", meta_file);
} }


Loading…
Cancel
Save