
Word2Vec example works, but still has a minor display issue.

tags/v0.9
Oceania2018 · 6 years ago · commit ffebb85e78
3 changed files with 109 additions and 11 deletions:
  1. src/TensorFlowNET.Core/Sessions/BaseSession.cs (+3 -0)
  2. test/TensorFlowNET.Examples/ObjectDetection.cs (+3 -1)
  3. test/TensorFlowNET.Examples/Text/Word2Vec.cs (+103 -10)

src/TensorFlowNET.Core/Sessions/BaseSession.cs (+3 -0)

@@ -80,6 +80,9 @@ namespace Tensorflow
     case int val:
         feed_dict_tensor[subfeed_t] = (NDArray)val;
         break;
+    case int[] val:
+        feed_dict_tensor[subfeed_t] = (NDArray)val;
+        break;
     case string val:
         feed_dict_tensor[subfeed_t] = (NDArray)val;
         break;
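
The new case int[] arm lets a managed int array be passed straight to a feed: BaseSession now wraps it as an NDArray itself, which the updated Word2Vec example below relies on when it feeds the int[] x_test. A minimal usage sketch, not part of this commit, assuming the TF.NET helpers (tf, with, FeedItem) seen elsewhere in this diff:

    // Sketch only: an int[] feed value now converts through the new case above.
    var ids = tf.placeholder(tf.int32, new TensorShape(3));
    with(tf.Session(), sess =>
    {
        // Fetch the placeholder back to confirm the round trip.
        var result = sess.run(ids, new FeedItem(ids, new int[] { 1, 2, 3 }));
        Console.WriteLine(result); // expected: [1, 2, 3]
    });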


test/TensorFlowNET.Examples/ObjectDetection.cs (+3 -1)

@@ -123,7 +123,9 @@ namespace TensorFlowNET.Examples
         }
     }

-    bitmap.Save(Path.Join(imageDir, "output.jpg"));
+    string path = Path.Join(imageDir, "output.jpg");
+    bitmap.Save(path);
+    Console.WriteLine($"Processed image is saved as {path}");
 }

 private void drawObjectOnBitmap(Bitmap bmp, Rectangle rect, float score, string name)


test/TensorFlowNET.Examples/Text/Word2Vec.cs (+103 -10)

@@ -1,4 +1,5 @@
+using NumSharp;
 using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
@@ -22,13 +23,15 @@ namespace TensorFlowNET.Examples
     // Training Parameters
     float learning_rate = 0.1f;
     int batch_size = 128;
-    int num_steps = 3000000;
-    int display_step = 10000;
-    int eval_step = 200000;
+    int num_steps = 30000; //3000000;
+    int display_step = 1000; //10000;
+    int eval_step = 5000; //200000;

     // Evaluation Parameters
     string[] eval_words = new string[] { "five", "of", "going", "hardware", "american", "britain" };
     string[] text_words;
+    List<WordId> word2id;
+    int[] data;

     // Word2Vec Parameters
     int embedding_size = 200; // Dimension of the embedding vector
@@ -38,7 +41,9 @@ namespace TensorFlowNET.Examples
     int num_skips = 2; // How many times to reuse an input to generate a label
     int num_sampled = 64; // Number of negative examples to sample

-    int data_index;
+    int data_index = 0;
+    int top_k = 8; // number of nearest neighbors
+    float average_loss = 0;

     public bool Run()
     {
@@ -48,21 +53,109 @@ namespace TensorFlowNET.Examples


     tf.train.import_meta_graph("graph/word2vec.meta");

+    // Input data
+    Tensor X = graph.OperationByName("Placeholder");
+    // Input label
+    Tensor Y = graph.OperationByName("Placeholder_1");
+
+    // Compute the average NCE loss for the batch
+    Tensor loss_op = graph.OperationByName("Mean");
+    // Define the optimizer
+    var train_op = graph.OperationByName("GradientDescent");
+    Tensor cosine_sim_op = graph.OperationByName("MatMul_1");
+
     // Initialize the variables (i.e. assign their default value)
     var init = tf.global_variables_initializer();

     with(tf.Session(graph), sess =>
     {
+        // Run the initializer
         sess.run(init);
+
+        var x_test = (from word in eval_words
+                      join id in word2id on word equals id.Word into wi
+                      from wi2 in wi.DefaultIfEmpty()
+                      select wi2 == null ? 0 : wi2.Id).ToArray();
+
+        foreach (var step in range(1, num_steps + 1))
+        {
+            // Get a new batch of data
+            var (batch_x, batch_y) = next_batch(batch_size, num_skips, skip_window);
+
+            var result = sess.run(new ITensorOrOperation[] { train_op, loss_op }, new FeedItem(X, batch_x), new FeedItem(Y, batch_y));
+            average_loss += result[1];
+
+            if (step % display_step == 0 || step == 1)
+            {
+                if (step > 1)
+                    average_loss /= display_step;
+
+                print($"Step {step}, Average Loss= {average_loss.ToString("F4")}");
+                average_loss = 0;
+            }
+
+            // Evaluation
+            if (step % eval_step == 0 || step == 1)
+            {
+                print("Evaluation...");
+                var sim = sess.run(cosine_sim_op, new FeedItem(X, x_test));
+                foreach (var i in range(len(eval_words)))
+                {
+                    var nearest = sim[i].argsort<float>()
+                        .Data<float>()
+                        .Take(top_k)
+                        .ToArray();
+                    string log_str = $"\"{eval_words[i]}\" nearest neighbors:";
+                    foreach (var k in range(top_k))
+                        log_str = $"{log_str} {word2id.First(x => x.Id == nearest[k]).Word},";
+                    print(log_str);
+                }
+            }
+        }
     });

-    return false;
+    return average_loss < 100;
 }
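
The "minor display issue" from the commit message plausibly lives in the evaluation block above: argsort<float>() sorts ascending, so taking the first top_k entries picks the least similar ids rather than the nearest neighbors. A hedged sketch of the conventional fix (plain LINQ over the row values, not the committed NumSharp calls; assumes sim[i].Data<float>() yields the similarity row):

    // Sketch only: rank ids by descending cosine similarity, then take top_k.
    float[] row = sim[i].Data<float>().ToArray();
    var nearest = Enumerable.Range(0, row.Length)
                            .OrderByDescending(id => row[id]) // most similar first
                            .Take(top_k)                      // optionally Skip(1) to drop the word itself
                            .ToArray();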


     // Generate training batch for the skip-gram model
-    private void next_batch()
+    private (NDArray, NDArray) next_batch(int batch_size, int num_skips, int skip_window)
     {
+        var batch = np.ndarray((batch_size), dtype: np.int32);
+        var labels = np.ndarray((batch_size, 1), dtype: np.int32);
+        // get window size (words left and right + current one)
+        int span = 2 * skip_window + 1;
+        var buffer = new Queue<int>(span);
+        if (data_index + span > data.Length)
+            data_index = 0;
+        data.Skip(data_index).Take(span).ToList().ForEach(x => buffer.Enqueue(x));
+        data_index += span;
+
+        foreach (var i in range(batch_size / num_skips))
+        {
+            var context_words = range(span).Where(x => x != skip_window).ToArray();
+            var words_to_use = new int[] { 1, 6 };
+            foreach (var (j, context_word) in enumerate(words_to_use))
+            {
+                batch[i * num_skips + j] = buffer.ElementAt(skip_window);
+                labels[i * num_skips + j, 0] = buffer.ElementAt(context_word);
+            }
+
+            if (data_index == len(data))
+            {
+                //buffer.extend(data[0:span]);
+                data_index = span;
+            }
+            else
+            {
+                buffer.Enqueue(data[data_index]);
+                data_index += 1;
+            }
+        }
+
+        // Backtrack a little bit to avoid skipping words in the end of a batch
+        data_index = (data_index + len(data) - span) % len(data);
+
+        return (batch, labels);
     }
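
For context, next_batch ports the skip-gram batch generator from the classic TensorFlow word2vec tutorial: a window of span = 2 * skip_window + 1 words slides over the corpus and each center word is paired with num_skips of its context words. A self-contained toy sketch of that idea (not the committed code: it samples context offsets at random, where the commit hardcodes { 1, 6 }, and uses plain arrays in place of NDArray; needs System, System.Collections.Generic, System.Linq):

    // Toy skip-gram batcher over a tiny corpus of word ids.
    int skipWindow = 3, numSkips = 2, batchSize = 8;
    int[] corpus = { 5, 1, 9, 2, 7, 4, 8, 3, 6, 0 };
    int span = 2 * skipWindow + 1;                           // center word + skipWindow on each side
    var rnd = new Random(0);
    var pairs = new List<(int center, int context)>();
    int pos = 0;
    while (pairs.Count < batchSize)
    {
        if (pos + span > corpus.Length) pos = 0;             // wrap at corpus end
        var window = corpus.Skip(pos).Take(span).ToArray();
        var offsets = Enumerable.Range(0, span)
                                .Where(o => o != skipWindow) // exclude the center itself
                                .OrderBy(_ => rnd.Next())    // random context choice
                                .Take(numSkips);
        foreach (var offset in offsets)
            pairs.Add((window[skipWindow], window[offset]));
        pos++;                                               // slide one word forward
    }
    pairs.ForEach(p => Console.WriteLine($"center {p.center} -> context {p.context}"));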


     public void PrepareData()
@@ -80,7 +173,7 @@ namespace TensorFlowNET.Examples
     int wordId = 0;
     text_words = File.ReadAllText(@"word2vec\text8").Trim().ToLower().Split();
     // Build the dictionary and replace rare words with UNK token
-    var word2id = text_words.GroupBy(x => x)
+    word2id = text_words.GroupBy(x => x)
         .Select(x => new WordId
         {
             Word = x.Key,
@@ -97,10 +190,10 @@ namespace TensorFlowNET.Examples
         .ToList();

     // Retrieve a word id, or assign it index 0 ('UNK') if not in dictionary
-    var data = (from word in text_words
+    data = (from word in text_words
             join id in word2id on word equals id.Word into wi
             from wi2 in wi.DefaultIfEmpty()
-            select wi2 == null ? 0 : wi2.Id).ToList();
+            select wi2 == null ? 0 : wi2.Id).ToArray();

     word2id.Insert(0, new WordId { Word = "UNK", Id = 0, Occurrence = data.Count(x => x == 0) });
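
Promoting word2id and data to fields is what lets Run() build x_test and resolve neighbor names. The mapping itself is the standard rare-words-to-UNK scheme: rank words by frequency, keep the top of the vocabulary, and send everything else to id 0. A standalone sketch of the same idea on a toy corpus (maxVocab and the inline corpus are illustrative; needs System, System.Collections.Generic, System.Linq):

    // Toy dictionary build: frequent words get ids 1..N by rank, the rest map to 0 (UNK).
    var words = "the cat sat on the mat the end".Split();
    int maxVocab = 3;                                  // keep the 3 most frequent words
    var vocab = words.GroupBy(w => w)
                     .OrderByDescending(g => g.Count())
                     .Take(maxVocab)
                     .Select((g, i) => (Word: g.Key, Id: i + 1))
                     .ToList();
    // Left join the corpus against the vocabulary; misses fall back to 0, as above.
    var ids = (from w in words
               join v in vocab on w equals v.Word into hit
               from h in hit.DefaultIfEmpty()
               select h.Word == null ? 0 : h.Id).ToArray();
    Console.WriteLine(string.Join(" ", ids));          // prints: 1 2 3 0 1 0 1 0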



