From e615351373c37791c749d06333e26e902076616e Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Fri, 10 May 2019 18:55:03 -0500 Subject: [PATCH] add MarShall.Copy for Boolean when creating Tensor. --- src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs | 3 +++ test/TensorFlowNET.Examples/Program.cs | 1 + .../TextProcess/TextClassificationTrain.cs | 12 +++++++----- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index 3b8b65dd..5ebaa092 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -55,6 +55,9 @@ namespace Tensorflow var nd1 = nd.ravel(); switch (nd.dtype.Name) { + case "Boolean": + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + break; case "Int16": Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); break; diff --git a/test/TensorFlowNET.Examples/Program.cs b/test/TensorFlowNET.Examples/Program.cs index 01928b88..9fd9c714 100644 --- a/test/TensorFlowNET.Examples/Program.cs +++ b/test/TensorFlowNET.Examples/Program.cs @@ -64,6 +64,7 @@ namespace TensorFlowNET.Examples disabled.ForEach(x => Console.WriteLine($"{x} is Disabled!", Color.Tan)); errors.ForEach(x => Console.WriteLine($"{x} is Failed!", Color.Red)); + Console.Write("Please [Enter] to quit."); Console.ReadLine(); } } diff --git a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs index e4f92cd2..15f1e55b 100644 --- a/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs +++ b/test/TensorFlowNET.Examples/TextProcess/TextClassificationTrain.cs @@ -61,7 +61,8 @@ namespace TensorFlowNET.Examples.CnnTextClassification var meta_file = model_name + ".meta"; tf.train.import_meta_graph(Path.Join("graph", meta_file)); Console.WriteLine("\tDONE"); - //sess.run(tf.global_variables_initializer()); // not necessary here, has already been done before meta graph export + // definitely necessary, otherwize will get the exception of "use uninitialized variable" + sess.run(tf.global_variables_initializer()); var train_batches = batch_iter(train_x, train_y, BATCH_SIZE, NUM_EPOCHS); var num_batches_per_epoch = (len(train_x) - 1); // BATCH_SIZE + 1 @@ -89,12 +90,13 @@ namespace TensorFlowNET.Examples.CnnTextClassification }; // original python: //_, step, loss = sess.run([model.optimizer, model.global_step, model.loss], feed_dict = train_feed_dict) - var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict); - loss_value = result[2]; + var result = sess.run(new Tensor[] { optimizer, global_step, loss }, train_feed_dict); + // exception here, loss value seems like a float[] + //loss_value = result[2]; var step = result[1]; + if (step % 10 == 0) + Console.WriteLine($"Step {step} loss: {result[2]}"); if (step % 100 == 0) - Console.WriteLine($"Step {step} loss: {loss_value}"); - if (step % 2000 == 0) { continue; // # Test accuracy with validation data for each epoch.