From 86c142e067e1d36847db7df476a362fa1676bc60 Mon Sep 17 00:00:00 2001
From: Meinrad Recheis
Date: Fri, 10 May 2019 22:35:09 +0200
Subject: [PATCH] test classification: improved batching performance with
 slicing

---
 .../TextProcess/DataHelpers.cs                |  26 ++--
 .../TextProcess/TextClassificationTrain.cs    | 121 ++++++++++--------
 2 files changed, 89 insertions(+), 58 deletions(-)

diff --git a/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs b/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
index 4bc1d84d..bae875e6 100644
--- a/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
@@ -13,7 +13,7 @@ namespace TensorFlowNET.Examples
         private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
         private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";
 
-        public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len, int? limit = null)
+        public static (NDArray, NDArray, int) build_char_dataset(string step, string model, int document_max_len, int? limit = null)
         {
             if (model != "vd_cnn")
                 throw new NotImplementedException(model);
@@ -29,22 +29,32 @@ namespace TensorFlowNET.Examples
             var contents = File.ReadAllLines(TRAIN_PATH);
             var size = limit == null ? contents.Length : limit.Value;
 
-            var x = new int[size][];
-            var y = new int[size];
+            var x = new NDArray(np.int32, new Shape(size, document_max_len));
+            var y = new NDArray(np.int32, new Shape(size));
+            var tenth = size / 10;
+            var percent = 0;
 
             for (int i = 0; i < size; i++)
             {
+                if ((i + 1) % tenth == 0)
+                {
+                    percent += 10;
+                    Console.WriteLine($"\t{percent}%");
+                }
+
                 string[] parts = contents[i].ToLower().Split(",\"").ToArray();
                 string content = parts[2];
                 content = content.Substring(0, content.Length - 1);
-                x[i] = new int[document_max_len];
+                var a = new int[document_max_len];
                 for (int j = 0; j < document_max_len; j++)
                 {
                     if (j >= content.Length)
-                        x[i][j] = char_dict["<pad>"];
+                        a[j] = char_dict["<pad>"];
+                        //x[i, j] = char_dict["<pad>"];
                     else
-                        x[i][j] = char_dict.ContainsKey(content[j].ToString()) ? char_dict[content[j].ToString()] : char_dict["<unk>"];
+                        a[j] = char_dict.ContainsKey(content[j].ToString()) ? char_dict[content[j].ToString()] : char_dict["<unk>"];
+                        //x[i, j] = char_dict.ContainsKey(content[j].ToString()) ? char_dict[content[j].ToString()] : char_dict["<unk>"];
                 }
-
+                x[i] = a;
                 y[i] = int.Parse(parts[0]);
             }
 
diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
index 94f990cf..e4f92cd2 100644
--- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
+++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
@@ -39,28 +39,30 @@ namespace TensorFlowNET.Examples.CnnTextClassification
         public bool Run()
         {
             PrepareData();
-            return with(tf.Session(), sess =>
+            var graph = tf.Graph().as_default();
+            return with(tf.Session(graph), sess =>
             {
                 if (ImportGraph)
-                    return RunWithImportedGraph(sess);
+                    return RunWithImportedGraph(sess, graph);
                 else
-                    return RunWithBuiltGraph(sess);
+                    return RunWithBuiltGraph(sess, graph);
             });
         }
 
-        protected virtual bool RunWithImportedGraph(Session sess)
+        protected virtual bool RunWithImportedGraph(Session sess, Graph graph)
        {
-            var graph = tf.Graph().as_default();
             Console.WriteLine("Building dataset...");
             var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);
+            Console.WriteLine("\tDONE");
 
             var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
 
-            var meta_file = model_name + "_untrained.meta";
+            Console.WriteLine("Import graph...");
+            var meta_file = model_name + ".meta";
             tf.train.import_meta_graph(Path.Join("graph", meta_file));
-
-            //sess.run(tf.global_variables_initializer()); // not necessary here, has already been done before meta graph export
-
+            Console.WriteLine("\tDONE");
+
+            //sess.run(tf.global_variables_initializer()); // not necessary here, has already been done before meta graph export
+
             var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
             var num_batches_per_epoch = (len(train_x) - 1); // BATCH_SIZE + 1
             double max_accuracy = 0;
@@ -68,25 +70,64 @@ namespace TensorFlowNET.Examples.CnnTextClassification
             Tensor is_training = graph.get_operation_by_name("is_training");
             Tensor model_x = graph.get_operation_by_name("x");
             Tensor model_y = graph.get_operation_by_name("y");
-            Tensor loss = graph.get_operation_by_name("Variable");
+            Tensor loss = graph.get_operation_by_name("loss/loss");
+            //var optimizer_nodes = graph._nodes_by_name.Keys.Where(key => key.Contains("optimizer")).ToArray();
+            Tensor optimizer = graph.get_operation_by_name("loss/optimizer");
+            Tensor global_step = graph.get_operation_by_name("global_step");
             Tensor accuracy = graph.get_operation_by_name("accuracy/accuracy");
 
+            int i = 0;
             foreach (var (x_batch, y_batch) in train_batches)
             {
+                i++;
+                Console.WriteLine("Training on batch " + i);
                 var train_feed_dict = new Hashtable
                 {
                     [model_x] = x_batch,
                     [model_y] = y_batch,
                     [is_training] = true,
-                };
-
+                };
+                // original python:
                 //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
+                var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict);
+                var loss_value = result[2];
+                var step = result[1];
+                if (step % 100 == 0)
+                    Console.WriteLine($"Step {step} loss: {loss_value}");
+                if (step % 2000 == 0)
+                {
+                    continue;
+                    // # Test accuracy with validation data for each epoch.
+                    var valid_batches = batch_iter(valid_x, valid_y, BATCH_SIZE, 1);
+                    var (sum_accuracy, cnt) = (0, 0);
+                    foreach (var (valid_x_batch, valid_y_batch) in valid_batches)
+                    {
+                        // valid_feed_dict = {
+                        //     model.x: valid_x_batch,
+                        //     model.y: valid_y_batch,
+                        //     model.is_training: False
+                        // }
+
+                        // accuracy = sess.run(model.accuracy, feed_dict = valid_feed_dict)
+                        // sum_accuracy += accuracy
+                        // cnt += 1
+                    }
+                    // valid_accuracy = sum_accuracy / cnt
+
+                    // print("\nValidation Accuracy = {1}\n".format(step // num_batches_per_epoch, sum_accuracy / cnt))
+
+                    // # Save model
+                    // if valid_accuracy > max_accuracy:
+                    //     max_accuracy = valid_accuracy
+                    //     saver.save(sess, "{0}/{1}.ckpt".format(args.model, args.model), global_step = step)
+                    //     print("Model is saved.\n")
+                }
             }
 
             return false;
         }
 
-        protected virtual bool RunWithBuiltGraph(Session session)
+        protected virtual bool RunWithBuiltGraph(Session session, Graph graph)
         {
             Console.WriteLine("Building dataset...");
             var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", model_name, CHAR_MAX_LEN, DataLimit);
@@ -104,51 +145,31 @@ namespace TensorFlowNET.Examples.CnnTextClassification
                     throw new NotImplementedException();
                     break;
                 case "vd_cnn":
-                    model=new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
+                    model = new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
                     break;
             }
 
             // todo train the model
             return false;
         }
 
-        private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
+        // TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
+        private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
         {
-            int len = x.Length;
-            int classes = y.Distinct().Count();
-            int samples = len / classes;
-            int train_size = int.Parse((samples * (1 - test_size)).ToString());
-
-            var train_x = new List<int[]>();
-            var valid_x = new List<int[]>();
-            var train_y = new List<int>();
-            var valid_y = new List<int>();
-
-            for (int i = 0; i < classes; i++)
-            {
-                for (int j = 0; j < samples; j++)
-                {
-                    int idx = i * samples + j;
-                    if (idx < train_size + samples * i)
-                    {
-                        train_x.Add(x[idx]);
-                        train_y.Add(y[idx]);
-                    }
-                    else
-                    {
-                        valid_x.Add(x[idx]);
-                        valid_y.Add(y[idx]);
-                    }
-                }
-            }
-
-            return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray());
+            Console.WriteLine("Splitting in Training and Testing data...");
+            int len = x.shape[0];
+            //int classes = y.Data<int>().Distinct().Count();
+            //int samples = len / classes;
+            int train_size = (int)Math.Round(len * (1 - test_size));
+            var train_x = x[new Slice(stop: train_size), new Slice()];
+            var valid_x = x[new Slice(start: train_size), new Slice()];
+            var train_y = y[new Slice(stop: train_size)];
+            var valid_y = y[new Slice(start: train_size)];
+            Console.WriteLine("\tDONE");
+            return (train_x, valid_x, train_y, valid_y);
         }
 
-        private IEnumerable<(NDArray, NDArray)> batch_iter(int[][] raw_inputs, int[] raw_outputs, int batch_size, int num_epochs)
+        private IEnumerable<(NDArray, NDArray)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)
         {
-            var inputs = np.array(raw_inputs);
-            var outputs = np.array(raw_outputs);
-
             var num_batches_per_epoch = (len(inputs) - 1); // batch_size + 1
             foreach (var epoch in range(num_epochs))
             {
@@ -156,7 +177,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
                 {
                     var start_index = batch_num * batch_size;
                     var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
-                    yield return (inputs[$"{start_index}:{end_index}"], outputs[$"{start_index}:{end_index}"]);
+                    yield return (inputs[new Slice(start_index, end_index)], outputs[new Slice(start_index, end_index)]);
                }
            }
        }
@@ -170,7 +191,7 @@ namespace TensorFlowNET.Examples.CnnTextClassification
             if (ImportGraph)
             {
                 // download graph meta data
-                var meta_file = model_name + "_untrained.meta";
+                var meta_file = model_name + ".meta";
                 url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/" + meta_file;
                 Web.Download(url, "graph", meta_file);
             }