
TextClassification: implemented batching using NumSharp slicing

tags/v0.9
Meinrad Recheis, 6 years ago
commit b04478d904
1 changed file with 41 additions and 3 deletions:
  +41 -3 test/TensorFlowNET.Examples/Text/TextClassificationTrain.cs

test/TensorFlowNET.Examples/Text/TextClassificationTrain.cs
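The change below relies on NumSharp's Python-style slice strings to cut mini-batches out of an NDArray. As a quick orientation before the diff, here is a minimal, self-contained sketch of that idiom. It is not part of the commit: the class name, toy data, and batch size are illustrative, and it assumes NumSharp's np.arange, reshape, shape, and string-based indexer behave like their NumPy counterparts.

    using System;
    using NumSharp;

    class SlicingBatchSketch
    {
        static void Main()
        {
            // Toy data: 10 samples with 4 features each (values 0..39).
            var data = np.arange(40).reshape(10, 4);

            int batch_size = 3;
            // Ceiling division: the last batch may hold fewer samples.
            int num_batches = (data.shape[0] - 1) / batch_size + 1;

            for (int b = 0; b < num_batches; b++)
            {
                int start = b * batch_size;
                int end = Math.Min((b + 1) * batch_size, data.shape[0]);

                // "start:end" selects rows [start, end), like NumPy's data[start:end].
                var batch = data[$"{start}:{end}"];
                Console.WriteLine($"batch {b}: {batch.shape[0]} x {batch.shape[1]}");
            }
        }
    }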

@@ -1,8 +1,10 @@
 using System;
+using System.Collections;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
 using System.Text;
+using NumSharp;
 using Tensorflow;
 using Tensorflow.Keras.Engine;
 using TensorFlowNET.Examples.Text.cnn_models;
@@ -29,6 +31,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification


         private const int CHAR_MAX_LEN = 1014;
         private const int NUM_CLASS = 2;
+        private const int BATCH_SIZE = 64;
+        private const int NUM_EPOCHS = 10;
         protected float loss_value = 0;

         public bool Run()
@@ -54,13 +58,30 @@ namespace TensorFlowNET.Examples.CnnTextClassification
             var meta_file = model_name + "_untrained.meta";
             tf.train.import_meta_graph(Path.Join("graph", meta_file));

-            //sess.run(tf.global_variables_initializer());
+            //sess.run(tf.global_variables_initializer()); // not necessary here, has already been done before meta graph export

+            var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS);
+            var num_batches_per_epoch = (len(train_x) - 1) / BATCH_SIZE + 1;
+            double max_accuracy = 0;

             Tensor is_training = graph.get_operation_by_name("is_training");
             Tensor model_x = graph.get_operation_by_name("x");
             Tensor model_y = graph.get_operation_by_name("y");
+            //Tensor loss = graph.get_operation_by_name("loss");
+            //Tensor accuracy = graph.get_operation_by_name("accuracy");
+            Tensor loss = graph.get_operation_by_name("Variable");
+            Tensor accuracy = graph.get_operation_by_name("accuracy/accuracy");

+            foreach (var (x_batch, y_batch) in train_batches)
+            {
+                var train_feed_dict = new Hashtable
+                {
+                    [model_x] = x_batch,
+                    [model_y] = y_batch,
+                    [is_training] = true,
+                };
+
+                //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict)
+            }

             return false;
         }

@@ -122,6 +143,23 @@ namespace TensorFlowNET.Examples.CnnTextClassification
             return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray());
         }

+        private IEnumerable<(NDArray, NDArray)> batch_iter(int[][] raw_inputs, int[] raw_outputs, int batch_size, int num_epochs)
+        {
+            var inputs = np.array(raw_inputs);
+            var outputs = np.array(raw_outputs);
+
+            var num_batches_per_epoch = (len(inputs) - 1) / batch_size + 1;
+            foreach (var epoch in range(num_epochs))
+            {
+                foreach (var batch_num in range(num_batches_per_epoch))
+                {
+                    var start_index = batch_num * batch_size;
+                    var end_index = Math.Min((batch_num + 1) * batch_size, len(inputs));
+                    yield return (inputs[$"{start_index}:{end_index}"], outputs[$"{start_index}:{end_index}"]);
+                }
+            }
+        }
+
         public void PrepareData()
         {
             string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
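
For reference, the batch_iter helper added above can be exercised on its own. The sketch below is not part of the commit: it restates the helper with plain loops in place of the len/range helpers from the examples project, assumes NumSharp's np.array overload for jagged arrays and string slicing (both of which the commit itself uses), and feeds it toy arrays purely to show the epoch/batch iteration order and the truncated final batch.

    using System;
    using System.Collections.Generic;
    using NumSharp;

    class BatchIterSketch
    {
        static IEnumerable<(NDArray, NDArray)> batch_iter(int[][] raw_inputs, int[] raw_outputs, int batch_size, int num_epochs)
        {
            var inputs = np.array(raw_inputs);
            var outputs = np.array(raw_outputs);

            // Ceiling division so every sample is visited; the last batch may be short.
            var num_batches_per_epoch = (inputs.shape[0] - 1) / batch_size + 1;
            for (var epoch = 0; epoch < num_epochs; epoch++)
            {
                for (var batch_num = 0; batch_num < num_batches_per_epoch; batch_num++)
                {
                    var start_index = batch_num * batch_size;
                    var end_index = Math.Min((batch_num + 1) * batch_size, inputs.shape[0]);
                    // NumSharp slice strings pick rows [start_index, end_index) from both arrays.
                    yield return (inputs[$"{start_index}:{end_index}"], outputs[$"{start_index}:{end_index}"]);
                }
            }
        }

        static void Main()
        {
            var x = new[] { new[] { 1, 2 }, new[] { 3, 4 }, new[] { 5, 6 }, new[] { 7, 8 }, new[] { 9, 10 } };
            var y = new[] { 0, 1, 0, 1, 0 };

            // With 5 samples and batch_size = 2, one epoch yields batches of 2, 2 and 1 samples.
            foreach (var (x_batch, y_batch) in batch_iter(x, y, batch_size: 2, num_epochs: 1))
                Console.WriteLine($"x: {x_batch.shape[0]} x {x_batch.shape[1]}, y: {y_batch.shape[0]}");
        }
    }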

