diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln
index 0f9960f0..115377c2 100644
--- a/TensorFlow.NET.sln
+++ b/TensorFlow.NET.sln
@@ -15,6 +15,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Visualization
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{E8340C61-12C1-4BEE-A340-403E7C1ACD82}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "scikit-learn", "..\scikit-learn.net\src\scikit-learn\scikit-learn.csproj", "{199DDAD8-4A6F-43B3-A560-C0393619E304}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -45,6 +47,10 @@ Global
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Debug|Any CPU.Build.0 = Debug|Any CPU
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.ActiveCfg = Release|Any CPU
{E8340C61-12C1-4BEE-A340-403E7C1ACD82}.Release|Any CPU.Build.0 = Release|Any CPU
+ {199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {199DDAD8-4A6F-43B3-A560-C0393619E304}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {199DDAD8-4A6F-43B3-A560-C0393619E304}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/src/TensorFlowNET.Utility/Web.cs b/src/TensorFlowNET.Utility/Web.cs
index dfaf5236..33df9c16 100644
--- a/src/TensorFlowNET.Utility/Web.cs
+++ b/src/TensorFlowNET.Utility/Web.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.IO;
+using System.Linq;
using System.Net;
using System.Text;
using System.Threading;
@@ -10,24 +11,31 @@ namespace TensorFlowNET.Utility
{
public class Web
{
- public static bool Download(string url, string file)
+ public static bool Download(string url, string destDir, string destFileName)
{
- if (File.Exists(file))
+ if (destFileName == null)
+ destFileName = url.Split(Path.DirectorySeparatorChar).Last();
+
+ Directory.CreateDirectory(destDir);
+
+ string relativeFilePath = Path.Combine(destDir, destFileName);
+
+ if (File.Exists(relativeFilePath))
{
- Console.WriteLine($"{file} already exists.");
+ Console.WriteLine($"{relativeFilePath} already exists.");
return false;
}
var wc = new WebClient();
- Console.WriteLine($"Downloading {file}");
- var download = Task.Run(() => wc.DownloadFile(url, file));
+ Console.WriteLine($"Downloading {relativeFilePath}");
+ var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
while (!download.IsCompleted)
{
Thread.Sleep(1000);
Console.Write(".");
}
Console.WriteLine("");
- Console.WriteLine($"Downloaded {file}");
+ Console.WriteLine($"Downloaded {relativeFilePath}");
return true;
}
diff --git a/test/TensorFlowNET.Examples/CnnTextClassification/CnnTextTrain.cs b/test/TensorFlowNET.Examples/CnnTextClassification/CnnTextTrain.cs
deleted file mode 100644
index ef4b0749..00000000
--- a/test/TensorFlowNET.Examples/CnnTextClassification/CnnTextTrain.cs
+++ /dev/null
@@ -1,58 +0,0 @@
-using NumSharp.Core;
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using Tensorflow;
-
-namespace TensorFlowNET.Examples.CnnTextClassification
-{
- public class CnnTextTrain : Python, IExample
- {
- // Percentage of the training data to use for validation
- private float dev_sample_percentage = 0.1f;
- // Data source for the positive data.
- private string positive_data_file = "https://raw.githubusercontent.com/dennybritz/cnn-text-classification-tf/master/data/rt-polaritydata/rt-polarity.pos";
- // Data source for the negative data.
- private string negative_data_file = "https://raw.githubusercontent.com/dennybritz/cnn-text-classification-tf/master/data/rt-polaritydata/rt-polarity.neg";
- // Dimensionality of character embedding (default: 128)
- private int embedding_dim = 128;
- // Comma-separated filter sizes (default: '3,4,5')
- private string filter_sizes = "3,4,5";
- // Number of filters per filter size (default: 128)
- private int num_filters = 128;
- // Dropout keep probability (default: 0.5)
- private float dropout_keep_prob = 0.5f;
- // L2 regularization lambda (default: 0.0)
- private float l2_reg_lambda = 0.0f;
- // Batch Size (default: 64)
- private int batch_size = 64;
- // Number of training epochs (default: 200)
- private int num_epochs = 200;
- // Evaluate model on dev set after this many steps (default: 100)
- private int evaluate_every = 100;
- // Save model after this many steps (default: 100)
- private int checkpoint_every = 100;
- // Number of checkpoints to store (default: 5)
- private int num_checkpoints = 5;
- // Allow device soft device placement
- private bool allow_soft_placement = true;
- // Log placement of ops on devices
- private bool log_device_placement = false;
-
- public void Run()
- {
- var (x_train, y_train, vocab_processor, x_dev, y_dev) = preprocess();
- }
-
- public (NDArray, NDArray, NDArray, NDArray, NDArray) preprocess()
- {
- var (x_text, y) = DataHelpers.load_data_and_labels(positive_data_file, negative_data_file);
-
- // Build vocabulary
- int max_document_length = x_text.Select(x => x.Split(' ').Length).Max();
- var vocab_processor = learn.preprocessing.VocabularyProcessor(max_document_length)
- throw new NotImplementedException("");
- }
- }
-}
diff --git a/test/TensorFlowNET.Examples/CnnTextClassification/TextCNN.cs b/test/TensorFlowNET.Examples/CnnTextClassification/TextCNN.cs
deleted file mode 100644
index 92f5717f..00000000
--- a/test/TensorFlowNET.Examples/CnnTextClassification/TextCNN.cs
+++ /dev/null
@@ -1,16 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Text;
-using Tensorflow;
-
-namespace TensorFlowNET.Examples.CnnTextClassification
-{
- ///
- /// Convolutional Neural Network for Text Classification
- /// https://github.com/dennybritz/cnn-text-classification-tf
- ///
- public class TextCNN : Python
- {
-
- }
-}
diff --git a/test/TensorFlowNET.Examples/ImageRecognition.cs b/test/TensorFlowNET.Examples/ImageRecognition.cs
index a4bf05f4..47d0ac07 100644
--- a/test/TensorFlowNET.Examples/ImageRecognition.cs
+++ b/test/TensorFlowNET.Examples/ImageRecognition.cs
@@ -85,15 +85,14 @@ namespace TensorFlowNET.Examples
// get model file
string url = "https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip";
- string zipFile = Path.Join(dir, "inception5h.zip");
- Utility.Web.Download(url, zipFile);
+ Utility.Web.Download(url, dir, "inception5h.zip");
- Utility.Compress.UnZip(zipFile, dir);
+ Utility.Compress.UnZip(Path.Join(dir, "inception5h.zip"), dir);
// download sample picture
- string pic = Path.Join(dir, "img", "grace_hopper.jpg");
Directory.CreateDirectory(Path.Join(dir, "img"));
- Utility.Web.Download($"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/grace_hopper.jpg", pic);
+ url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/grace_hopper.jpg";
+ Utility.Web.Download(url, Path.Join(dir, "img"), "grace_hopper.jpg");
}
}
}
diff --git a/test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs b/test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs
index aaae7c1a..75be7738 100644
--- a/test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs
+++ b/test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs
@@ -90,14 +90,14 @@ namespace TensorFlowNET.Examples
// get model file
string url = "https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz";
- string zipFile = Path.Join(dir, $"{pbFile}.tar.gz");
- Utility.Web.Download(url, zipFile);
+ Utility.Web.Download(url, dir, $"{pbFile}.tar.gz");
- Utility.Compress.ExtractTGZ(zipFile, dir);
+ Utility.Compress.ExtractTGZ(Path.Join(dir, $"{pbFile}.tar.gz"), dir);
// download sample picture
string pic = "grace_hopper.jpg";
- Utility.Web.Download($"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{pic}", Path.Join(dir, pic));
+ url = $"https://raw.githubusercontent.com/tensorflow/tensorflow/master/tensorflow/examples/label_image/data/{pic}";
+ Utility.Web.Download(url, dir, pic);
}
}
}
diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
index 34a29361..b1349ec7 100644
--- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
+++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
@@ -12,6 +12,7 @@
+
diff --git a/test/TensorFlowNET.Examples/CnnTextClassification/DataHelpers.cs b/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs
similarity index 53%
rename from test/TensorFlowNET.Examples/CnnTextClassification/DataHelpers.cs
rename to test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs
index 043b6945..586a978a 100644
--- a/test/TensorFlowNET.Examples/CnnTextClassification/DataHelpers.cs
+++ b/test/TensorFlowNET.Examples/TextClassification/DataHelpers.cs
@@ -10,6 +10,44 @@ namespace TensorFlowNET.Examples.CnnTextClassification
{
public class DataHelpers
{
+ private const string TRAIN_PATH = "text_classification/dbpedia_csv/train.csv";
+ private const string TEST_PATH = "text_classification/dbpedia_csv/test.csv";
+
+ public static (int[][], int[], int) build_char_dataset(string step, string model, int document_max_len)
+ {
+ string alphabet = "abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:’'\"/|_#$%ˆ&*˜‘+=<>()[]{} ";
+ /*if (step == "train")
+ df = pd.read_csv(TRAIN_PATH, names =["class", "title", "content"]);*/
+ var char_dict = new Dictionary();
+ char_dict[""] = 0;
+ char_dict[""] = 1;
+ foreach (char c in alphabet)
+ char_dict[c.ToString()] = char_dict.Count;
+
+ var contents = File.ReadAllLines(TRAIN_PATH);
+
+ var x = new int[contents.Length][];
+ var y = new int[contents.Length];
+ for (int i = 0; i < contents.Length; i++)
+ {
+ string[] parts = contents[i].ToLower().Split(",\"").ToArray();
+ string content = parts[2];
+ content = content.Substring(0, content.Length - 1);
+ x[i] = new int[document_max_len];
+ for (int j = 0; j < document_max_len; j++)
+ {
+ if (j >= content.Length)
+ x[i][j] = char_dict[""];
+ else
+ x[i][j] = char_dict.ContainsKey(content[j].ToString()) ? char_dict[content[j].ToString()] : char_dict[""];
+ }
+
+ y[i] = int.Parse(parts[0]);
+ }
+
+ return (x, y, alphabet.Length + 2);
+ }
+
///
/// Loads MR polarity data from files, splits the data into words and generates labels.
/// Returns split sentences and labels.
@@ -20,8 +58,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification
public static (string[], NDArray) load_data_and_labels(string positive_data_file, string negative_data_file)
{
Directory.CreateDirectory("CnnTextClassification");
- Utility.Web.Download(positive_data_file, "CnnTextClassification/rt-polarity.pos");
- Utility.Web.Download(negative_data_file, "CnnTextClassification/rt-polarity.neg");
+ Utility.Web.Download(positive_data_file, "CnnTextClassification", "rt -polarity.pos");
+ Utility.Web.Download(negative_data_file, "CnnTextClassification", "rt-polarity.neg");
// Load data from files
var positive_examples = File.ReadAllLines("CnnTextClassification/rt-polarity.pos")
diff --git a/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs
new file mode 100644
index 00000000..52ed5469
--- /dev/null
+++ b/test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs
@@ -0,0 +1,37 @@
+using NumSharp.Core;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+using Tensorflow;
+using TensorFlowNET.Utility;
+
+namespace TensorFlowNET.Examples.CnnTextClassification
+{
+ ///
+ /// https://github.com/dongjun-Lee/text-classification-models-tf
+ ///
+ public class TextClassificationTrain : Python, IExample
+ {
+ private string dataDir = "text_classification";
+ private string dataFileName = "dbpedia_csv.tar.gz";
+
+ private const int CHAR_MAX_LEN = 1014;
+
+ public void Run()
+ {
+ download_dbpedia();
+ Console.WriteLine("Building dataset...");
+ var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);
+ var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15);
+ }
+
+ public void download_dbpedia()
+ {
+ string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
+ Web.Download(url, dataDir, dataFileName);
+ Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
+ }
+ }
+}
diff --git a/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs b/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs
index b57da319..cd59287b 100644
--- a/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs
+++ b/test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs
@@ -46,9 +46,8 @@ namespace TensorFlowNET.Examples
// get model file
string url = $"https://github.com/SciSharp/TensorFlow.NET/raw/master/data/{dataFile}";
- string zipFile = Path.Join(dir, $"imdb.zip");
- Utility.Web.Download(url, zipFile);
- Utility.Compress.UnZip(zipFile, dir);
+ Utility.Web.Download(url, dir, "imdb.zip");
+ Utility.Compress.UnZip(Path.Join(dir, $"imdb.zip"), dir);
// prepare training dataset
var x_train = ReadData(Path.Join(dir, "x_train.txt"));