Browse Source

TextClassification: reverted to simpler shuffling and created dbpedia_subset.zip for debugging

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
b5d8165438
4 changed files with 79 additions and 91 deletions
  1. BIN
      data/dbpedia_subset.zip
  2. +3
    -1
      test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs
  3. +54
    -90
      test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs
  4. +22
    -0
      test/TensorFlowNET.Examples/Utility/ArrayShuffling.cs

BIN
data/dbpedia_subset.zip View File


+ 3
- 1
test/TensorFlowNET.Examples/TextProcess/DataHelpers.cs View File

@@ -5,6 +5,7 @@ using System.IO;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using TensorFlowNET.Examples.Utility;

namespace TensorFlowNET.Examples
{
@@ -25,7 +26,8 @@ namespace TensorFlowNET.Examples
char_dict["<unk>"] = 1;
foreach (char c in alphabet)
char_dict[c.ToString()] = char_dict.Count;
var contents = File.ReadAllLines(TRAIN_PATH);
var contents = new Random(17).Shuffle( File.ReadAllLines(TRAIN_PATH));
//File.WriteAllLines("text_classification/dbpedia_csv/train_6400.csv", contents.Take(6400));
var size = limit == null ? contents.Length : limit.Value;

var x = new int[size][];


+ 54
- 90
test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs View File

@@ -59,8 +59,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification
Console.WriteLine("\tDONE ");

var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
Console.WriteLine("Training set size: " + train_x.Length);
Console.WriteLine("Test set size: " + valid_x.Length);
Console.WriteLine("Training set size: " + train_x.len);
Console.WriteLine("Test set size: " + valid_x.len);

Console.WriteLine("Import graph...");
var meta_file = model_name + ".meta";
@@ -164,106 +164,70 @@ namespace TensorFlowNET.Examples.CnnTextClassification
}
// TODO: this originally is an SKLearn utility function. it randomizes train and test which we don't do here
//private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
//{
// Console.WriteLine("Splitting in Training and Testing data...");
// int len = x.shape[0];
// //int classes = y.Data<int>().Distinct().Count();
// //int samples = len / classes;
// int train_size = (int)Math.Round(len * (1 - test_size));
// var train_x = x[new Slice(stop: train_size), new Slice()];
// var valid_x = x[new Slice(start: train_size + 1), new Slice()];
// var train_y = y[new Slice(stop: train_size)];
// var valid_y = y[new Slice(start: train_size + 1)];
// Console.WriteLine("\tDONE");
// return (train_x, valid_x, train_y, valid_y);
//}

private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
{
Console.WriteLine("Splitting in Training and Testing data...");
var stopwatch = Stopwatch.StartNew();
int len = x.Length;
//int classes = y.Distinct().Count();
//int samples = len / classes;
int train_size = int.Parse((len * (1 - test_size)).ToString());

//var train_x = new List<int[]>();
//var valid_x = new List<int[]>();
//var train_y = new List<int>();
//var valid_y = new List<int>();
private (NDArray, NDArray, NDArray, NDArray) train_test_split(NDArray x, NDArray y, float test_size = 0.3f)
{
Console.WriteLine("Splitting in Training and Testing data...");
int len = x.shape[0];
//int classes = y.Data<int>().Distinct().Count();
//int samples = len / classes;
int train_size = (int)Math.Round(len * (1 - test_size));
var train_x = x[new Slice(stop: train_size), new Slice()];
var valid_x = x[new Slice(start: train_size + 1), new Slice()];
var train_y = y[new Slice(stop: train_size)];
var valid_y = y[new Slice(start: train_size + 1)];
Console.WriteLine("\tDONE");
return (train_x, valid_x, train_y, valid_y);
}

//for (int i = 0; i < classes; i++)
//{
// for (int j = 0; j < samples; j++)
// {
// int idx = i * samples + j;
// if (idx < train_size + samples * i)
// {
// train_x.Add(x[idx]);
// train_y.Add(y[idx]);
// }
// else
// {
// valid_x.Add(x[idx]);
// valid_y.Add(y[idx]);
// }
// }
//}
var random = new Random(17);
//private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
//{
// Console.WriteLine("Splitting in Training and Testing data...");
// var stopwatch = Stopwatch.StartNew();
// int len = x.Length;
// int train_size = int.Parse((len * (1 - test_size)).ToString());
// var random = new Random(17);

// we collect indices of labels
var labels = new Dictionary<int, HashSet<int>>();
var shuffled_indices = Shuffle<int>(random, range(len).ToArray());
foreach (var i in shuffled_indices)
{
var label = y[i];
if (!labels.ContainsKey(i))
labels[label] = new HashSet<int>();
labels[label].Add(i);
}
// // we collect indices of labels
// var labels = new Dictionary<int, HashSet<int>>();
// var shuffled_indices = random.Shuffle<int>(range(len).ToArray());
// foreach (var i in shuffled_indices)
// {
// var label = y[i];
// if (!labels.ContainsKey(i))
// labels[label] = new HashSet<int>();
// labels[label].Add(i);
// }

var train_x = new int[train_size][];
var valid_x = new int[len - train_size][];
var train_y = new int[train_size];
var valid_y = new int[len - train_size];
// var train_x = new int[train_size][];
// var valid_x = new int[len - train_size][];
// var train_y = new int[train_size];
// var valid_y = new int[len - train_size];
FillWithShuffledLabels(x, y, train_x, train_y, random, labels);
FillWithShuffledLabels(x, y, valid_x, valid_y, random, labels);
// FillWithShuffledLabels(x, y, train_x, train_y, random, labels);
// FillWithShuffledLabels(x, y, valid_x, valid_y, random, labels);

Console.WriteLine("\tDONE " + stopwatch.Elapsed);
return (train_x, valid_x, train_y, valid_y);
}
// Console.WriteLine("\tDONE " + stopwatch.Elapsed);
// return (train_x, valid_x, train_y, valid_y);
//}

private static void FillWithShuffledLabels(int[][] x, int[] y, int[][] shuffled_x, int[] shuffled_y, Random random, Dictionary<int, HashSet<int>> labels)
{
int i = 0;
var label_keys = labels.Keys.ToArray();
while (i < shuffled_x.Length)
{
foreach (var key in Shuffle<int>(random, labels.Keys.ToArray()))
{
var set = labels[key];
var index = set.First();
if (set.Count == 0)
labels.Remove(key); // remove the set as it is empty
shuffled_x[i] = x[index];
shuffled_y[i] = y[index];
i++;
}
}
}

public static T[] Shuffle<T>(Random rng, T[] array)
{
int n = array.Length;
while (n > 1)
{
int k = rng.Next(n--);
T temp = array[n];
array[n] = array[k];
array[k] = temp;
var key = label_keys[random.Next(label_keys.Length)];
var set = labels[key];
var index = set.First();
if (set.Count == 0)
{
labels.Remove(key); // remove the set as it is empty
label_keys = labels.Keys.ToArray();
}
shuffled_x[i] = x[index];
shuffled_y[i] = y[index];
i++;
}
return array;
}

private IEnumerable<(NDArray, NDArray, int)> batch_iter(NDArray inputs, NDArray outputs, int batch_size, int num_epochs)


+ 22
- 0
test/TensorFlowNET.Examples/Utility/ArrayShuffling.cs View File

@@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Text;
namespace TensorFlowNET.Examples.Utility
{
public static class ArrayShuffling
{
public static T[] Shuffle<T>(this Random rng, T[] array)
{
int n = array.Length;
while (n > 1)
{
int k = rng.Next(n--);
T temp = array[n];
array[n] = array[k];
array[k] = temp;
}
return array;
}
}
}

Loading…
Cancel
Save