| @@ -1,8 +1,10 @@ | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using NumSharp; | |||
| using Tensorflow; | |||
| using Tensorflow.Keras.Engine; | |||
| using TensorFlowNET.Examples.Text.cnn_models; | |||
| @@ -29,6 +31,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||
| private const int CHAR_MAX_LEN = 1014; | |||
| private const int NUM_CLASS = 2; | |||
| private const int BATCH_SIZE = 64; | |||
| private const int NUM_EPOCHS = 10; | |||
| protected float loss_value = 0; | |||
| public bool Run() | |||
| @@ -54,13 +58,30 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||
| var meta_file = model_name + "_untrained.meta"; | |||
| tf.train.import_meta_graph(Path.Join("graph", meta_file)); | |||
| //sess.run(tf.global_variables_initializer()); | |||
| //sess.run(tf.global_variables_initializer()); // not necessary here, has already been done before meta graph export | |||
| var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS); | |||
| var num_batches_per_epoch = (len(train_x) - 1); // BATCH_SIZE + 1 | |||
| double max_accuracy = 0; | |||
| Tensor is_training = graph.get_operation_by_name("is_training"); | |||
| Tensor model_x = graph.get_operation_by_name("x"); | |||
| Tensor model_y = graph.get_operation_by_name("y"); | |||
| //Tensor loss = graph.get_operation_by_name("loss"); | |||
| //Tensor accuracy = graph.get_operation_by_name("accuracy"); | |||
| Tensor loss = graph.get_operation_by_name("Variable"); | |||
| Tensor accuracy = graph.get_operation_by_name("accuracy/accuracy"); | |||
| foreach (var (x_batch, y_batch) in train_batches) | |||
| { | |||
| var train_feed_dict = new Hashtable | |||
| { | |||
| [model_x] = x_batch, | |||
| [model_y] = y_batch, | |||
| [is_training] = true, | |||
| }; | |||
| //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict) | |||
| } | |||
| return false; | |||
| } | |||
| @@ -122,6 +143,23 @@ namespace TensorFlowNET.Examples.CnnTextClassification | |||
| return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray()); | |||
| } | |||
| private IEnumerable<(NDArray, NDArray)> batch_iter(int[][] raw_inputs, int[] raw_outputs, int batch_size, int num_epochs) | |||
| { | |||
| var inputs = np.array(raw_inputs); | |||
| var outputs = np.array(raw_outputs); | |||
| var num_batches_per_epoch = (len(inputs) - 1); // batch_size + 1 | |||
| foreach (var epoch in range(num_epochs)) | |||
| { | |||
| foreach (var batch_num in range(num_batches_per_epoch)) | |||
| { | |||
| var start_index = batch_num * batch_size; | |||
| var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs)); | |||
| yield return (inputs[$"{start_index}:{end_index}"], outputs[$"{start_index}:{end_index}"]); | |||
| } | |||
| } | |||
| } | |||
| public void PrepareData() | |||
| { | |||
| string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz"; | |||