@@ -1,8 +1,10 @@
 using System;
+using System.Collections;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Text;
+using NumSharp;
 using Tensorflow;
 using Tensorflow.Keras.Engine;
 using TensorFlowNET.Examples.Text.cnn_models;
@@ -29,6 +31,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification
         private const int CHAR_MAX_LEN = 1014;
         private const int NUM_CLASS = 2;
+        private const int BATCH_SIZE = 64;
+        private const int NUM_EPOCHS = 10;

         protected float loss_value = 0;

         public bool Run()
@@ -54,13 +58,30 @@ namespace TensorFlowNET.Examples.CnnTextClassification
                 var meta_file = model_name + "_untrained.meta";
                 tf.train.import_meta_graph(Path.Join("graph", meta_file));
-                //sess.run(tf.global_variables_initializer());
+                //sess.run(tf.global_variables_initializer()); // not necessary here, has already been done before meta graph export

+                var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
+                var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
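+                // ("(n - 1) / BATCH_SIZE + 1" is integer arithmetic for ceil(n / BATCH_SIZE))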
+                double max_accuracy = 0;
                 Tensor is_training = graph.get_operation_by_name("is_training");
                 Tensor model_x = graph.get_operation_by_name("x");
                 Tensor model_y = graph.get_operation_by_name("y");
+                //Tensor loss = graph.get_operation_by_name("loss");
+                //Tensor accuracy = graph.get_operation_by_name("accuracy");
+                Tensor loss = graph.get_operation_by_name("Variable");
+                Tensor accuracy = graph.get_operation_by_name("accuracy/accuracy");
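+                // NOTE: the exported graph apparently has no op named "loss" yet (see the
+                // commented-out line above), so the "Variable" op is fetched as a stand-in.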
+
+                foreach (var (x_batch, y_batch) in train_batches)
+                {
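+                    // Feed the graph's input placeholders with this batch and enable training mode.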
+                    var train_feed_dict = new Hashtable
+                    {
+                        [model_x] = x_batch,
+                        [model_y] = y_batch,
+                        [is_training] = true,
+                    };
+                    //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
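+                    // A C# port of the Python line above would look roughly like the sketch
+                    // below; the "loss/optimizer" and "Global_Step" op names are assumptions
+                    // for illustration, not names confirmed by this change:
+                    //var optimizer = graph.get_operation_by_name("loss/optimizer");
+                    //var global_step = graph.get_operation_by_name("Global_Step");
+                    //var results = sess.run(new object[] { optimizer, global_step, loss }, train_feed_dict);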
+                }

                 return false;
             }
@@ -122,6 +143,23 @@ namespace TensorFlowNET.Examples.CnnTextClassification
             return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray());
         }
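+
+        // Yields the training set as (x, y) NDArray mini-batches for num_epochs passes over
+        // the data; the last batch of an epoch may be shorter than batch_size.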
+        private IEnumerable<(NDArray, NDArray)> batch_iter(int[][] raw_inputs, int[] raw_outputs, int batch_size, int num_epochs)
+        {
+            var inputs = np.array(raw_inputs);
+            var outputs = np.array(raw_outputs);
+            var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
+            foreach (var epoch in range(num_epochs))
+            {
+                foreach (var batch_num in range(num_batches_per_epoch))
+                {
+                    var start_index = batch_num * batch_size;
+                    var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
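+                    // NumSharp accepts Python-style string slices, so "start:end" selects
+                    // rows [start_index, end_index), just like inputs[start:end] in NumPy.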
| yield return (inputs[$"{start_index}:{end_index}"], outputs[$"{start_index}:{end_index}"]); | |||||
| } | |||||
| } | |||||
| } | |||||

         public void PrepareData()
         {
             string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";