From 85319e0febc78dc71abf8b2cc17c6ce771ce4311 Mon Sep 17 00:00:00 2001 From: Kerry Jiang Date: Wed, 31 Jul 2019 04:58:56 -0700 Subject: [PATCH 01/11] started to use MnistModelLoader in Tensorflow.Hub (#330) --- .gitignore | 4 + .../BasicModels/KMeansClustering.cs | 28 ++-- .../BasicModels/LogisticRegression.cs | 26 ++-- .../BasicModels/NearestNeighbor.cs | 10 +- .../ImageProcessing/DigitRecognitionCNN.cs | 16 +-- .../ImageProcessing/DigitRecognitionNN.cs | 40 ++---- .../ImageProcessing/DigitRecognitionRNN.cs | 18 +-- .../TensorFlowNET.Examples.csproj | 1 + .../Utility/DataSetMnist.cs | 95 ------------- .../Utility/Datasets.cs | 46 ------ .../Utility/IDataSet.cs | 10 -- test/TensorFlowNET.Examples/Utility/MNIST.cs | 131 ------------------ 12 files changed, 69 insertions(+), 356 deletions(-) delete mode 100644 test/TensorFlowNET.Examples/Utility/DataSetMnist.cs delete mode 100644 test/TensorFlowNET.Examples/Utility/Datasets.cs delete mode 100644 test/TensorFlowNET.Examples/Utility/IDataSet.cs delete mode 100644 test/TensorFlowNET.Examples/Utility/MNIST.cs diff --git a/.gitignore b/.gitignore index eee1dc7b..ce600fbb 100644 --- a/.gitignore +++ b/.gitignore @@ -332,3 +332,7 @@ src/TensorFlowNET.Native/bazel-* src/TensorFlowNET.Native/c_api.h /.vscode test/TensorFlowNET.Examples/mnist + + +# training model resources +.resources diff --git a/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs b/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs index c0ca95b3..3b52a75e 100644 --- a/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs +++ b/test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs @@ -18,7 +18,7 @@ using NumSharp; using System; using System.Diagnostics; using Tensorflow; -using TensorFlowNET.Examples.Utility; +using Tensorflow.Hub; using static Tensorflow.Python; namespace TensorFlowNET.Examples @@ -39,7 +39,7 @@ namespace TensorFlowNET.Examples public int? test_size = null; public int batch_size = 1024; // The number of samples per batch - Datasets mnist; + Datasets mnist; NDArray full_data_x; int num_steps = 20; // Total steps to train int k = 25; // The number of clusters @@ -62,19 +62,31 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size:validation_size, test_size:test_size); - full_data_x = mnist.train.data; + var loader = new MnistModelLoader(); + + var setting = new ModelLoadSetting + { + TrainDir = ".resources/mnist", + OneHot = true, + TrainSize = train_size, + ValidationSize = validation_size, + TestSize = test_size + }; + + mnist = loader.LoadAsync(setting).Result; + + full_data_x = mnist.Train.Data; // download graph meta data string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta"; - Web.Download(url, "graph", "kmeans.meta"); + loader.DownloadAsync(url, ".resources/graph", "kmeans.meta").Wait(); } public Graph ImportGraph() { var graph = tf.Graph().as_default(); - tf.train.import_meta_graph("graph/kmeans.meta"); + tf.train.import_meta_graph(".resources/graph/kmeans.meta"); return graph; } @@ -132,7 +144,7 @@ namespace TensorFlowNET.Examples sw.Start(); foreach (var i in range(idx.Length)) { - var x = mnist.train.labels[i]; + var x = mnist.Train.Labels[i]; counts[idx[i]] += x; } @@ -153,7 +165,7 @@ namespace TensorFlowNET.Examples var accuracy_op = tf.reduce_mean(cast); // Test Model - var (test_x, test_y) = (mnist.test.data, mnist.test.labels); + var (test_x, test_y) = (mnist.Test.Data, mnist.Test.Labels); result = sess.run(accuracy_op, new FeedItem(X, test_x), new FeedItem(Y, test_y)); accuray_test = result; print($"Test Accuracy: {accuray_test}"); diff --git a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs index 185dd1fe..1d7808b7 100644 --- a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs @@ -19,7 +19,7 @@ using System; using System.Diagnostics; using System.IO; using Tensorflow; -using TensorFlowNET.Examples.Utility; +using Tensorflow.Hub; using static Tensorflow.Python; namespace TensorFlowNET.Examples @@ -45,7 +45,7 @@ namespace TensorFlowNET.Examples private float learning_rate = 0.01f; private int display_step = 1; - Datasets mnist; + Datasets mnist; public bool Run() { @@ -84,11 +84,11 @@ namespace TensorFlowNET.Examples sw.Start(); var avg_cost = 0.0f; - var total_batch = mnist.train.num_examples / batch_size; + var total_batch = mnist.Train.NumOfExamples / batch_size; // Loop over all batches foreach (var i in range(total_batch)) { - var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size); + var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(batch_size); // Run optimization op (backprop) and cost op (to get loss value) var result = sess.run(new object[] { optimizer, cost }, new FeedItem(x, batch_xs), @@ -115,7 +115,7 @@ namespace TensorFlowNET.Examples var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); // Calculate accuracy var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); - float acc = accuracy.eval(new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels)); + float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); print($"Accuracy: {acc.ToString("F4")}"); return acc > 0.9; @@ -124,23 +124,23 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size); + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: train_size, validationSize: validation_size, testSize: test_size).Result; } public void SaveModel(Session sess) { var saver = tf.train.Saver(); - var save_path = saver.save(sess, "logistic_regression/model.ckpt"); - tf.train.write_graph(sess.graph, "logistic_regression", "model.pbtxt", as_text: true); + var save_path = saver.save(sess, ".resources/logistic_regression/model.ckpt"); + tf.train.write_graph(sess.graph, ".resources/logistic_regression", "model.pbtxt", as_text: true); - FreezeGraph.freeze_graph(input_graph: "logistic_regression/model.pbtxt", + FreezeGraph.freeze_graph(input_graph: ".resources/logistic_regression/model.pbtxt", input_saver: "", input_binary: false, - input_checkpoint: "logistic_regression/model.ckpt", + input_checkpoint: ".resources/logistic_regression/model.ckpt", output_node_names: "Softmax", restore_op_name: "save/restore_all", filename_tensor_name: "save/Const:0", - output_graph: "logistic_regression/model.pb", + output_graph: ".resources/logistic_regression/model.pb", clear_devices: true, initializer_nodes: ""); } @@ -148,7 +148,7 @@ namespace TensorFlowNET.Examples public void Predict(Session sess) { var graph = new Graph().as_default(); - graph.Import(Path.Join("logistic_regression", "model.pb")); + graph.Import(Path.Join(".resources/logistic_regression", "model.pb")); // restoring the model // var saver = tf.train.import_meta_graph("logistic_regression/tensorflowModel.ckpt.meta"); @@ -159,7 +159,7 @@ namespace TensorFlowNET.Examples var input = x.outputs[0]; // predict - var (batch_xs, batch_ys) = mnist.train.next_batch(10); + var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(10); var results = sess.run(output, new FeedItem(input, batch_xs[np.arange(1)])); if (results.argmax() == (batch_ys[0] as NDArray).argmax()) diff --git a/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs b/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs index 86ecd281..d1d867a2 100644 --- a/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs +++ b/test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs @@ -17,7 +17,7 @@ using NumSharp; using System; using Tensorflow; -using TensorFlowNET.Examples.Utility; +using Tensorflow.Hub; using static Tensorflow.Python; namespace TensorFlowNET.Examples @@ -31,7 +31,7 @@ namespace TensorFlowNET.Examples { public bool Enabled { get; set; } = true; public string Name => "Nearest Neighbor"; - Datasets mnist; + Datasets mnist; NDArray Xtr, Ytr, Xte, Yte; public int? TrainSize = null; public int ValidationSize = 5000; @@ -84,10 +84,10 @@ namespace TensorFlowNET.Examples public void PrepareData() { - mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize); + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true, trainSize: TrainSize, validationSize: ValidationSize, testSize: TestSize).Result; // In this example, we limit mnist data - (Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) - (Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing + (Xtr, Ytr) = mnist.Train.GetNextBatch(TrainSize == null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates) + (Xte, Yte) = mnist.Test.GetNextBatch(TestSize == null ? 200 : TestSize.Value / 100); // 200 for testing } public Graph ImportGraph() diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs index a5c757b9..d2a1b9f4 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs @@ -18,7 +18,7 @@ using NumSharp; using System; using System.Diagnostics; using Tensorflow; -using TensorFlowNET.Examples.Utility; +using Tensorflow.Hub; using static Tensorflow.Python; namespace TensorFlowNET.Examples.ImageProcess @@ -46,7 +46,7 @@ namespace TensorFlowNET.Examples.ImageProcess int epochs = 5; // accuracy > 98% int batch_size = 100; float learning_rate = 0.001f; - Datasets mnist; + Datasets mnist; // Network configuration // 1st Convolutional Layer @@ -310,14 +310,14 @@ namespace TensorFlowNET.Examples.ImageProcess public void PrepareData() { - mnist = MNIST.read_data_sets("mnist", one_hot: true); - (x_train, y_train) = Reformat(mnist.train.data, mnist.train.labels); - (x_valid, y_valid) = Reformat(mnist.validation.data, mnist.validation.labels); - (x_test, y_test) = Reformat(mnist.test.data, mnist.test.labels); + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; + (x_train, y_train) = Reformat(mnist.Train.Data, mnist.Train.Labels); + (x_valid, y_valid) = Reformat(mnist.Validation.Data, mnist.Validation.Labels); + (x_test, y_test) = Reformat(mnist.Test.Data, mnist.Test.Labels); print("Size of:"); - print($"- Training-set:\t\t{len(mnist.train.data)}"); - print($"- Validation-set:\t{len(mnist.validation.data)}"); + print($"- Training-set:\t\t{len(mnist.Train.Data)}"); + print($"- Validation-set:\t{len(mnist.Validation.Data)}"); } /// diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs index 09fdc818..059c5419 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs @@ -17,7 +17,7 @@ using NumSharp; using System; using Tensorflow; -using TensorFlowNET.Examples.Utility; +using Tensorflow.Hub; using static Tensorflow.Python; namespace TensorFlowNET.Examples.ImageProcess @@ -44,7 +44,7 @@ namespace TensorFlowNET.Examples.ImageProcess int batch_size = 100; float learning_rate = 0.001f; int h1 = 200; // number of nodes in the 1st hidden layer - Datasets mnist; + Datasets mnist; Tensor x, y; Tensor loss, accuracy; @@ -121,13 +121,13 @@ namespace TensorFlowNET.Examples.ImageProcess public void PrepareData() { - mnist = MNIST.read_data_sets("mnist", one_hot: true); + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; } public void Train(Session sess) { // Number of training iterations in each epoch - var num_tr_iter = mnist.train.labels.len / batch_size; + var num_tr_iter = mnist.Train.Labels.len / batch_size; var init = tf.global_variables_initializer(); sess.run(init); @@ -139,13 +139,13 @@ namespace TensorFlowNET.Examples.ImageProcess { print($"Training epoch: {epoch + 1}"); // Randomly shuffle the training data at the beginning of each epoch - var (x_train, y_train) = randomize(mnist.train.data, mnist.train.labels); + var (x_train, y_train) = mnist.Randomize(mnist.Train.Data, mnist.Train.Labels); foreach (var iteration in range(num_tr_iter)) { var start = iteration * batch_size; var end = (iteration + 1) * batch_size; - var (x_batch, y_batch) = get_next_batch(x_train, y_train, start, end); + var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end); // Run optimization op (backprop) sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); @@ -161,7 +161,8 @@ namespace TensorFlowNET.Examples.ImageProcess } // Run validation after every epoch - var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels)); + var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Validation.Data), new FeedItem(y, mnist.Validation.Labels)); + loss_val = results1[0]; accuracy_val = results1[1]; print("---------------------------------------------------------"); @@ -172,35 +173,12 @@ namespace TensorFlowNET.Examples.ImageProcess public void Test(Session sess) { - var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels)); + var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); loss_test = result[0]; accuracy_test = result[1]; print("---------------------------------------------------------"); print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); print("---------------------------------------------------------"); } - - private (NDArray, NDArray) randomize(NDArray x, NDArray y) - { - var perm = np.random.permutation(y.shape[0]); - - np.random.shuffle(perm); - return (mnist.train.data[perm], mnist.train.labels[perm]); - } - - /// - /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) - /// - /// - /// - /// - /// - /// - private (NDArray, NDArray) get_next_batch(NDArray x, NDArray y, int start, int end) - { - var x_batch = x[$"{start}:{end}"]; - var y_batch = y[$"{start}:{end}"]; - return (x_batch, y_batch); - } } } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs index d51ca9ad..babf62f3 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs @@ -17,7 +17,7 @@ using NumSharp; using System; using Tensorflow; -using TensorFlowNET.Examples.Utility; +using Tensorflow.Hub; using static Tensorflow.Python; namespace TensorFlowNET.Examples.ImageProcess @@ -45,7 +45,7 @@ namespace TensorFlowNET.Examples.ImageProcess int n_inputs = 28; int n_outputs = 10; - Datasets mnist; + Datasets mnist; Tensor x, y; Tensor loss, accuracy, cls_prediction; @@ -143,15 +143,15 @@ namespace TensorFlowNET.Examples.ImageProcess public void PrepareData() { - mnist = MNIST.read_data_sets("mnist", one_hot: true); - (x_train, y_train) = (mnist.train.data, mnist.train.labels); - (x_valid, y_valid) = (mnist.validation.data, mnist.validation.labels); - (x_test, y_test) = (mnist.test.data, mnist.test.labels); + mnist = MnistModelLoader.LoadAsync(".resources/mnist", oneHot: true).Result; + (x_train, y_train) = (mnist.Train.Data, mnist.Train.Labels); + (x_valid, y_valid) = (mnist.Validation.Data, mnist.Validation.Labels); + (x_test, y_test) = (mnist.Test.Data, mnist.Test.Labels); print("Size of:"); - print($"- Training-set:\t\t{len(mnist.train.data)}"); - print($"- Validation-set:\t{len(mnist.validation.data)}"); - print($"- Test-set:\t\t{len(mnist.test.data)}"); + print($"- Training-set:\t\t{len(mnist.Train.Data)}"); + print($"- Validation-set:\t{len(mnist.Validation.Data)}"); + print($"- Test-set:\t\t{len(mnist.Test.Data)}"); } public Graph ImportGraph() => throw new NotImplementedException(); diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 149bd549..6184d4ad 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -18,5 +18,6 @@ + diff --git a/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs b/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs deleted file mode 100644 index 0017eba5..00000000 --- a/test/TensorFlowNET.Examples/Utility/DataSetMnist.cs +++ /dev/null @@ -1,95 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using Tensorflow; - -namespace TensorFlowNET.Examples.Utility -{ - public class DataSetMnist : IDataSet - { - public int num_examples { get; } - - public int epochs_completed { get; private set; } - public int index_in_epoch { get; private set; } - public NDArray data { get; private set; } - public NDArray labels { get; private set; } - - public DataSetMnist(NDArray images, NDArray labels, TF_DataType dtype, bool reshape) - { - num_examples = images.shape[0]; - images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); - images.astype(dtype.as_numpy_datatype()); - images = np.multiply(images, 1.0f / 255.0f); - - labels.astype(dtype.as_numpy_datatype()); - - data = images; - this.labels = labels; - epochs_completed = 0; - index_in_epoch = 0; - } - - public (NDArray, NDArray) next_batch(int batch_size, bool fake_data = false, bool shuffle = true) - { - var start = index_in_epoch; - // Shuffle for the first epoch - if(epochs_completed == 0 && start == 0 && shuffle) - { - var perm0 = np.arange(num_examples); - np.random.shuffle(perm0); - data = data[perm0]; - labels = labels[perm0]; - } - - // Go to the next epoch - if (start + batch_size > num_examples) - { - // Finished epoch - epochs_completed += 1; - - // Get the rest examples in this epoch - var rest_num_examples = num_examples - start; - //var images_rest_part = _images[np.arange(start, _num_examples)]; - //var labels_rest_part = _labels[np.arange(start, _num_examples)]; - // Shuffle the data - if (shuffle) - { - var perm = np.arange(num_examples); - np.random.shuffle(perm); - data = data[perm]; - labels = labels[perm]; - } - - start = 0; - index_in_epoch = batch_size - rest_num_examples; - var end = index_in_epoch; - var images_new_part = data[np.arange(start, end)]; - var labels_new_part = labels[np.arange(start, end)]; - - /*return (np.concatenate(new float[][] { images_rest_part.Data(), images_new_part.Data() }, axis: 0), - np.concatenate(new float[][] { labels_rest_part.Data(), labels_new_part.Data() }, axis: 0));*/ - return (images_new_part, labels_new_part); - } - else - { - index_in_epoch += batch_size; - var end = index_in_epoch; - return (data[np.arange(start, end)], labels[np.arange(start, end)]); - } - } - } -} diff --git a/test/TensorFlowNET.Examples/Utility/Datasets.cs b/test/TensorFlowNET.Examples/Utility/Datasets.cs deleted file mode 100644 index 0c8c4e2d..00000000 --- a/test/TensorFlowNET.Examples/Utility/Datasets.cs +++ /dev/null @@ -1,46 +0,0 @@ -using NumSharp; - -namespace TensorFlowNET.Examples.Utility -{ - public class Datasets where T : IDataSet - { - private T _train; - public T train => _train; - - private T _validation; - public T validation => _validation; - - private T _test; - public T test => _test; - - public Datasets(T train, T validation, T test) - { - _train = train; - _validation = validation; - _test = test; - } - - public (NDArray, NDArray) Randomize(NDArray x, NDArray y) - { - var perm = np.random.permutation(y.shape[0]); - - np.random.shuffle(perm); - return (x[perm], y[perm]); - } - - /// - /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) - /// - /// - /// - /// - /// - /// - public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) - { - var x_batch = x[$"{start}:{end}"]; - var y_batch = y[$"{start}:{end}"]; - return (x_batch, y_batch); - } - } -} diff --git a/test/TensorFlowNET.Examples/Utility/IDataSet.cs b/test/TensorFlowNET.Examples/Utility/IDataSet.cs deleted file mode 100644 index 31be57c1..00000000 --- a/test/TensorFlowNET.Examples/Utility/IDataSet.cs +++ /dev/null @@ -1,10 +0,0 @@ -using NumSharp; - -namespace TensorFlowNET.Examples.Utility -{ - public interface IDataSet - { - NDArray data { get; } - NDArray labels { get; } - } -} diff --git a/test/TensorFlowNET.Examples/Utility/MNIST.cs b/test/TensorFlowNET.Examples/Utility/MNIST.cs deleted file mode 100644 index 73d6fe2a..00000000 --- a/test/TensorFlowNET.Examples/Utility/MNIST.cs +++ /dev/null @@ -1,131 +0,0 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using System.IO; -using Tensorflow; - -namespace TensorFlowNET.Examples.Utility -{ - public class MNIST - { - private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"; - private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz"; - private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz"; - private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz"; - private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz"; - public static Datasets read_data_sets(string train_dir, - bool one_hot = false, - TF_DataType dtype = TF_DataType.TF_FLOAT, - bool reshape = true, - int validation_size = 5000, - int? train_size = null, - int? test_size = null, - string source_url = DEFAULT_SOURCE_URL) - { - if (train_size!=null && validation_size >= train_size) - throw new ArgumentException("Validation set should be smaller than training set"); - - Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES); - Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir); - var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]), limit: train_size); - - Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS); - Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir); - var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot, limit: train_size); - - Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES); - Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir); - var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]), limit: test_size); - - Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS); - Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir); - var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot, limit:test_size); - - int end = train_images.shape[0]; - var validation_images = train_images[np.arange(validation_size)]; - var validation_labels = train_labels[np.arange(validation_size)]; - train_images = train_images[np.arange(validation_size, end)]; - train_labels = train_labels[np.arange(validation_size, end)]; - - var train = new DataSetMnist(train_images, train_labels, dtype, reshape); - var validation = new DataSetMnist(validation_images, validation_labels, dtype, reshape); - var test = new DataSetMnist(test_images, test_labels, dtype, reshape); - - return new Datasets(train, validation, test); - } - - public static NDArray extract_images(string file, int? limit=null) - { - using (var bytestream = new FileStream(file, FileMode.Open)) - { - var magic = _read32(bytestream); - if (magic != 2051) - throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}"); - var num_images = _read32(bytestream); - num_images = limit == null ? num_images : Math.Min(num_images, (uint)limit); - var rows = _read32(bytestream); - var cols = _read32(bytestream); - var buf = new byte[rows * cols * num_images]; - bytestream.Read(buf, 0, buf.Length); - var data = np.frombuffer(buf, np.uint8); - data = data.reshape((int)num_images, (int)rows, (int)cols, 1); - return data; - } - } - - public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10, int? limit = null) - { - using (var bytestream = new FileStream(file, FileMode.Open)) - { - var magic = _read32(bytestream); - if (magic != 2049) - throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}"); - var num_items = _read32(bytestream); - num_items = limit == null ? num_items : Math.Min(num_items,(uint) limit); - var buf = new byte[num_items]; - bytestream.Read(buf, 0, buf.Length); - var labels = np.frombuffer(buf, np.uint8); - if (one_hot) - return dense_to_one_hot(labels, num_classes); - return labels; - } - } - - private static NDArray dense_to_one_hot(NDArray labels_dense, int num_classes) - { - var num_labels = labels_dense.shape[0]; - var index_offset = np.arange(num_labels) * num_classes; - var labels_one_hot = np.zeros(num_labels, num_classes); - - for(int row = 0; row < num_labels; row++) - { - var col = labels_dense.Data(row); - labels_one_hot.SetData(1.0, row, col); - } - - return labels_one_hot; - } - - private static uint _read32(FileStream bytestream) - { - var buffer = new byte[sizeof(uint)]; - var count = bytestream.Read(buffer, 0, 4); - return np.frombuffer(buffer, ">u4").Data(0); - } - } -} From cffc950509495c8687d0173bd70944d38de514df Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 31 Jul 2019 06:59:49 -0500 Subject: [PATCH 02/11] Add TensorFlowDatasets project. --- TensorFlow.NET.sln | 8 +- src/TensorFlowDatasets/DatasetBuilder.cs | 24 ++++++ src/TensorFlowDatasets/DownloadConfig.cs | 10 +++ .../TensorFlowDatasets.csproj | 19 +++++ .../ImageProcessing/CIFAR10-CNN.cs | 74 +++++++++++++++++++ .../ImageProcessing/DigitRecognitionCNN.cs | 2 +- .../ImageProcessing/DigitRecognitionNN.cs | 2 +- .../ImageProcessing/DigitRecognitionRNN.cs | 2 +- .../ImageProcessing/ImageBackgroundRemoval.cs | 2 +- .../ImageProcessing/RetrainImageClassifier.cs | 2 +- .../TensorFlowNET.Examples.csproj | 1 + 11 files changed, 140 insertions(+), 6 deletions(-) create mode 100644 src/TensorFlowDatasets/DatasetBuilder.cs create mode 100644 src/TensorFlowDatasets/DownloadConfig.cs create mode 100644 src/TensorFlowDatasets/TensorFlowDatasets.csproj create mode 100644 test/TensorFlowNET.Examples/ImageProcessing/CIFAR10-CNN.cs diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 5d6e5fe7..542d09c1 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -17,7 +17,9 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowBenchmark", "src\ EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowHub", "src\TensorFlowHub\TensorFlowHub.csproj", "{8FD59A5A-97EB-457E-B9F1-D88B0C822C6E}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TensorFlowText", "src\TensorFlowText\TensorFlowText.csproj", "{B598E5D5-BD2D-4191-8532-F2FBAC31AB81}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowText", "src\TensorFlowText\TensorFlowText.csproj", "{B598E5D5-BD2D-4191-8532-F2FBAC31AB81}" +EndProject +Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TensorFlowDatasets", "src\TensorFlowDatasets\TensorFlowDatasets.csproj", "{DF151A51-E9FD-41BD-B0F4-08A743755D44}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -57,6 +59,10 @@ Global {B598E5D5-BD2D-4191-8532-F2FBAC31AB81}.Debug|Any CPU.Build.0 = Debug|Any CPU {B598E5D5-BD2D-4191-8532-F2FBAC31AB81}.Release|Any CPU.ActiveCfg = Release|Any CPU {B598E5D5-BD2D-4191-8532-F2FBAC31AB81}.Release|Any CPU.Build.0 = Release|Any CPU + {DF151A51-E9FD-41BD-B0F4-08A743755D44}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {DF151A51-E9FD-41BD-B0F4-08A743755D44}.Debug|Any CPU.Build.0 = Debug|Any CPU + {DF151A51-E9FD-41BD-B0F4-08A743755D44}.Release|Any CPU.ActiveCfg = Release|Any CPU + {DF151A51-E9FD-41BD-B0F4-08A743755D44}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/src/TensorFlowDatasets/DatasetBuilder.cs b/src/TensorFlowDatasets/DatasetBuilder.cs new file mode 100644 index 00000000..bfb78d6e --- /dev/null +++ b/src/TensorFlowDatasets/DatasetBuilder.cs @@ -0,0 +1,24 @@ +using System; + +namespace TensorFlowDatasets +{ + /// + /// Abstract base class for all datasets. + /// + public class DatasetBuilder + { + /// + /// Downloads and prepares dataset for reading. + /// + /// + /// directory where downloaded files are stored. + /// + /// + /// further configuration for downloading and preparing dataset. + /// + public void download_and_prepare(string download_dir = null, DownloadConfig download_config = null) + { + + } + } +} diff --git a/src/TensorFlowDatasets/DownloadConfig.cs b/src/TensorFlowDatasets/DownloadConfig.cs new file mode 100644 index 00000000..0488e273 --- /dev/null +++ b/src/TensorFlowDatasets/DownloadConfig.cs @@ -0,0 +1,10 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace TensorFlowDatasets +{ + public class DownloadConfig + { + } +} diff --git a/src/TensorFlowDatasets/TensorFlowDatasets.csproj b/src/TensorFlowDatasets/TensorFlowDatasets.csproj new file mode 100644 index 00000000..1b839c1f --- /dev/null +++ b/src/TensorFlowDatasets/TensorFlowDatasets.csproj @@ -0,0 +1,19 @@ + + + + netcoreapp2.2 + SciSharp.TensorFlowDatasets + 0.0.1 + SciSharp Team + TensorFlow Datasets + true + https://avatars3.githubusercontent.com/u/44989469?s=200&v=4 + http://scisharpstack.org + TensorFlow Datasets provides many public datasets as tf.data.Datasets. + https://github.com/SciSharp/TensorFlow.NET + git + SciSharp, Dataset, TensorFlow + Apache 2.0 + + + diff --git a/test/TensorFlowNET.Examples/ImageProcessing/CIFAR10-CNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/CIFAR10-CNN.cs new file mode 100644 index 00000000..a77a5b00 --- /dev/null +++ b/test/TensorFlowNET.Examples/ImageProcessing/CIFAR10-CNN.cs @@ -0,0 +1,74 @@ +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow; +using TensorFlowDatasets; +using static Tensorflow.Python; + +namespace TensorFlowNET.Examples +{ + /// + /// https://www.tensorflow.org/tutorials/images/deep_cnn + /// + public class CIFAR10_CNN : IExample + { + public bool Enabled { get; set; } = true; + public bool IsImportingGraph { get; set; } = false; + + public string Name => "CIFAR-10 CNN"; + + public bool Run() + { + PrepareData(); + + return true; + } + + public Graph BuildGraph() + { + throw new NotImplementedException(); + } + + public Graph ImportGraph() + { + throw new NotImplementedException(); + } + + public void Predict(Session sess) + { + throw new NotImplementedException(); + } + + public void PrepareData() + { + var tfds = new DatasetBuilder(); + tfds.download_and_prepare(); + } + + public void Test(Session sess) + { + throw new NotImplementedException(); + } + + public void Train(Session sess) + { + throw new NotImplementedException(); + } + } +} diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs index d2a1b9f4..ac763da2 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs @@ -21,7 +21,7 @@ using Tensorflow; using Tensorflow.Hub; using static Tensorflow.Python; -namespace TensorFlowNET.Examples.ImageProcess +namespace TensorFlowNET.Examples { /// /// Convolutional Neural Network classifier for Hand Written Digits diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs index 059c5419..e604afff 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs @@ -20,7 +20,7 @@ using Tensorflow; using Tensorflow.Hub; using static Tensorflow.Python; -namespace TensorFlowNET.Examples.ImageProcess +namespace TensorFlowNET.Examples { /// /// Neural Network classifier for Hand Written Digits diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs index babf62f3..f769371a 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs @@ -20,7 +20,7 @@ using Tensorflow; using Tensorflow.Hub; using static Tensorflow.Python; -namespace TensorFlowNET.Examples.ImageProcess +namespace TensorFlowNET.Examples { /// /// Recurrent Neural Network for handwritten digits MNIST. diff --git a/test/TensorFlowNET.Examples/ImageProcessing/ImageBackgroundRemoval.cs b/test/TensorFlowNET.Examples/ImageProcessing/ImageBackgroundRemoval.cs index c43c853a..db148f14 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/ImageBackgroundRemoval.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/ImageBackgroundRemoval.cs @@ -4,7 +4,7 @@ using Tensorflow; using TensorFlowNET.Examples.Utility; using static Tensorflow.Python; -namespace TensorFlowNET.Examples.ImageProcess +namespace TensorFlowNET.Examples { /// /// This example removes the background from an input image. diff --git a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs index 4d3a858f..96da8d1c 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs @@ -25,7 +25,7 @@ using Tensorflow; using TensorFlowNET.Examples.Utility; using static Tensorflow.Python; -namespace TensorFlowNET.Examples.ImageProcess +namespace TensorFlowNET.Examples { /// /// In this tutorial, we will reuse the feature extraction capabilities from powerful image classifiers trained on ImageNet diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj index 6184d4ad..f4e2340a 100644 --- a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj @@ -16,6 +16,7 @@ + From c938fa895ffbbff2969e7125dbb81232bbd20358 Mon Sep 17 00:00:00 2001 From: Antonio Date: Wed, 31 Jul 2019 21:43:12 +0200 Subject: [PATCH 03/11] Add missing `operator >=`s and `operator <=`s (#331) * Add missing `operator >=`s Also unit testing the new operators. * Add missing `operator <=`s Also unit testing all the operator cases. --- .../Tensors/Tensor.Operators.cs | 18 +- test/TensorFlowNET.UnitTest/OperationsTest.cs | 308 ++++++++++++++++++ 2 files changed, 324 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index a5a9b674..de26e28b 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -87,7 +87,6 @@ namespace Tensorflow public static Tensor operator >(int x, Tensor y) => gen_math_ops.greater(x, y); public static Tensor operator >(Tensor x, Tensor y) => gen_math_ops.greater(x, y); public static Tensor operator >(Tensor x, int y) => gen_math_ops.greater(x, y); - public static Tensor operator >=(Tensor x, Tensor y) => gen_math_ops.greater_equal(x, y); public static Tensor operator >(Tensor x, float y) => gen_math_ops.greater(x, y); public static Tensor operator >(Tensor x, double y) => gen_math_ops.greater(x, y); @@ -96,10 +95,25 @@ namespace Tensorflow public static Tensor operator <(int x, Tensor y) => gen_math_ops.less(x, y); public static Tensor operator <(Tensor x, Tensor y) => gen_math_ops.less(x, y); public static Tensor operator <(Tensor x, int y) => gen_math_ops.less(x, y); - public static Tensor operator <=(Tensor x, Tensor y) => gen_math_ops.less_equal(x, y); public static Tensor operator <(Tensor x, float y) => gen_math_ops.less(x, y); public static Tensor operator <(Tensor x, double y) => gen_math_ops.less(x, y); + public static Tensor operator >=(double x, Tensor y) => gen_math_ops.greater_equal(x, y); + public static Tensor operator >=(float x, Tensor y) => gen_math_ops.greater_equal(x, y); + public static Tensor operator >=(int x, Tensor y) => gen_math_ops.greater_equal(x, y); + public static Tensor operator >=(Tensor x, Tensor y) => gen_math_ops.greater_equal(x, y); + public static Tensor operator >=(Tensor x, int y) => gen_math_ops.greater_equal(x, y); + public static Tensor operator >=(Tensor x, float y) => gen_math_ops.greater_equal(x, y); + public static Tensor operator >=(Tensor x, double y) => gen_math_ops.greater_equal(x, y); + + public static Tensor operator <=(int x, Tensor y) => gen_math_ops.less_equal(x, y); + public static Tensor operator <=(float x, Tensor y) => gen_math_ops.less_equal(x, y); + public static Tensor operator <=(double x, Tensor y) => gen_math_ops.less_equal(x, y); + public static Tensor operator <=(Tensor x, Tensor y) => gen_math_ops.less_equal(x, y); + public static Tensor operator <=(Tensor x, int y) => gen_math_ops.less_equal(x, y); + public static Tensor operator <=(Tensor x, float y) => gen_math_ops.less_equal(x, y); + public static Tensor operator <=(Tensor x, double y) => gen_math_ops.less_equal(x, y); + private static Tensor BinaryOpWrapper(string name, Tx x, Ty y) { TF_DataType dtype = TF_DataType.DtInvalid; diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 358f3fb9..a0a3b5e4 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -1141,5 +1141,313 @@ namespace TensorFlowNET.UnitTest } #endregion } + + [TestMethod] + public void greaterOrEqualThanOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int intThreshold = 10; + + var firstIntFeed = Enumerable.Range(0, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(intThreshold, rows * cols).ToArray(); + var intResult = firstIntFeed.Count(elem => elem >= intThreshold); + var intResultTwo = firstIntFeed.Count(elem => elem <= intThreshold); + + var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater_equal(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator >=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator >=(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= intThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator >=(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(intThreshold >= a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResultTwo); + } + #endregion + + #region floatTest + const float floatThreshold = 10.0f; + + var firstFloatFeed = Enumerable.Range(0, rows * cols).Select(elem => (float)elem).ToArray(); + var secondFloatFeed = Enumerable.Repeat(floatThreshold, rows * cols).ToArray(); + var floatResult = firstFloatFeed.Count(elem => elem >= floatThreshold); + var floatResultTwo = firstFloatFeed.Count(elem => elem <= floatThreshold); + + a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater_equal(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator >=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator >=(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= floatThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator >=(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(floatThreshold >= a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResultTwo); + } + #endregion + + #region doubleTest + const double doubleThreshold = 10.0; + + var firstDoubleFeed = Enumerable.Repeat(0, rows * cols).Select(elem => (double)elem).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(doubleThreshold, rows * cols).ToArray(); + var doubleResult = firstDoubleFeed.Count(elem => elem >= doubleThreshold); + var doubleResultTwo = firstDoubleFeed.Count(elem => elem <= doubleThreshold); + + a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater_equal(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator >=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator >=(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= doubleThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator >=(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(doubleThreshold >= a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResultTwo); + } + #endregion + } + + [TestMethod] + public void lessOrEqualThanOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int intThreshold = 10; + + var firstIntFeed = Enumerable.Range(0, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(intThreshold, rows * cols).ToArray(); + var intResult = firstIntFeed.Count(elem => elem <= intThreshold); + var intResultTwo = firstIntFeed.Count(elem => elem >= intThreshold); + + var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less_equal(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator <=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a <= b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator <=(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a <= intThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator <=(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(intThreshold <= a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResultTwo); + } + #endregion + + #region floatTest + const float floatThreshold = 10.0f; + + var firstFloatFeed = Enumerable.Range(0, rows * cols).Select(elem => (float)elem).ToArray(); + var secondFloatFeed = Enumerable.Repeat(floatThreshold, rows * cols).ToArray(); + var floatResult = firstFloatFeed.Count(elem => elem <= floatThreshold); + var floatResultTwo = firstFloatFeed.Count(elem => elem >= floatThreshold); + + a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less_equal(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator <=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a <= b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator <=(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a <= floatThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator <=(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(floatThreshold <= a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResultTwo); + } + #endregion + + #region doubleTest + const double doubleThreshold = 10.0; + + var firstDoubleFeed = Enumerable.Repeat(0, rows * cols).Select(elem => (double)elem).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(doubleThreshold, rows * cols).ToArray(); + var doubleResult = firstDoubleFeed.Count(elem => elem <= doubleThreshold); + var doubleResultTwo = firstDoubleFeed.Count(elem => elem >= doubleThreshold); + + a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.less_equal(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator <=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a <= b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator <=(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a <= doubleThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator <=(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(doubleThreshold <= a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResultTwo); + } + #endregion + } } } From 83b62e494d724d699ab0f2205818238b9aeee949 Mon Sep 17 00:00:00 2001 From: Kerry Jiang Date: Wed, 31 Jul 2019 22:10:20 -0700 Subject: [PATCH 04/11] fixed a bug about loading datasets --- src/TensorFlowHub/MnistModelLoader.cs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/TensorFlowHub/MnistModelLoader.cs b/src/TensorFlowHub/MnistModelLoader.cs index b22d359e..121c0961 100644 --- a/src/TensorFlowHub/MnistModelLoader.cs +++ b/src/TensorFlowHub/MnistModelLoader.cs @@ -22,8 +22,7 @@ namespace Tensorflow.Hub var setting = new ModelLoadSetting { TrainDir = trainDir, - OneHot = oneHot, - TrainSize = trainSize + OneHot = oneHot }; if (trainSize.HasValue) @@ -99,7 +98,7 @@ namespace Tensorflow.Hub var train = new MnistDataSet(trainImages, trainLabels, dtype, reshape); var validation = new MnistDataSet(validationImages, validationLabels, dtype, reshape); - var test = new MnistDataSet(trainImages, trainLabels, dtype, reshape); + var test = new MnistDataSet(testImages, testLabels, dtype, reshape); return new Datasets(train, validation, test); } From eb3b56ee90a5f70c04d44666c360fbefe46bd2a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=85=E6=B0=B8?= Date: Thu, 1 Aug 2019 18:09:16 +0800 Subject: [PATCH 06/11] =?UTF-8?q?fix=20#333=20:=20Add=20a=20project=20file?= =?UTF-8?q?=20for=20GPU=20examples=20=E3=80=82=20(#336)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- TensorFlow.NET.sln | 6 +++++ .../TensorFlowNET.Examples.GPU.csproj | 24 +++++++++++++++++++ 2 files changed, 30 insertions(+) create mode 100644 test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln index 542d09c1..689965c4 100644 --- a/TensorFlow.NET.sln +++ b/TensorFlow.NET.sln @@ -21,6 +21,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowText", "src\Tenso EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "TensorFlowDatasets", "src\TensorFlowDatasets\TensorFlowDatasets.csproj", "{DF151A51-E9FD-41BD-B0F4-08A743755D44}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples.GPU", "test\TensorFlowNET.Examples\TensorFlowNET.Examples.GPU.csproj", "{6F6B3382-8F87-4CD9-BF87-C81D5405685A}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -63,6 +65,10 @@ Global {DF151A51-E9FD-41BD-B0F4-08A743755D44}.Debug|Any CPU.Build.0 = Debug|Any CPU {DF151A51-E9FD-41BD-B0F4-08A743755D44}.Release|Any CPU.ActiveCfg = Release|Any CPU {DF151A51-E9FD-41BD-B0F4-08A743755D44}.Release|Any CPU.Build.0 = Release|Any CPU + {6F6B3382-8F87-4CD9-BF87-C81D5405685A}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6F6B3382-8F87-4CD9-BF87-C81D5405685A}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6F6B3382-8F87-4CD9-BF87-C81D5405685A}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6F6B3382-8F87-4CD9-BF87-C81D5405685A}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE diff --git a/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj new file mode 100644 index 00000000..1bd3d530 --- /dev/null +++ b/test/TensorFlowNET.Examples/TensorFlowNET.Examples.GPU.csproj @@ -0,0 +1,24 @@ + + + + Exe + netcoreapp2.2 + false + + + + + + + + + + + + + + + + + + From ae22b6d2b2c7dc1250fd83bb74b5a86ad636716c Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Wed, 31 Jul 2019 23:50:57 -0500 Subject: [PATCH 07/11] Fixed Session.LoadFromSavedModel(), Found accuracy regression for Logistic Regression. --- src/TensorFlowNET.Core/Sessions/Session.cs | 11 +++++++++-- src/TensorFlowNET.Core/TensorFlowNET.Core.csproj | 10 ++++++---- .../BasicModels/LogisticRegression.cs | 2 +- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index 21c4de09..374a57ad 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -50,9 +50,16 @@ namespace Tensorflow var graph = c_api.TF_NewGraph(); var status = new Status(); var opt = c_api.TF_NewSessionOptions(); - + var tags = new string[] { "serve" }; var buffer = new TF_Buffer(); - var sess = c_api.TF_LoadSessionFromSavedModel(opt, IntPtr.Zero, path, new string[0], 0, graph, ref buffer, status); + var sess = c_api.TF_LoadSessionFromSavedModel(opt, + IntPtr.Zero, + path, + tags, + tags.Length, + graph, + ref buffer, + status); //var bytes = new Buffer(buffer.data).Data; //var meta_graph = MetaGraphDef.Parser.ParseFrom(bytes); diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 1ec4f6f3..7374f82f 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.10.3 + 0.10.4 Haiping Chen, Meinrad Recheis SciSharp STACK true @@ -17,7 +17,7 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.10.3.0 + 0.10.4.0 Changes since v0.9.0: 1. Added full connected Convolution Neural Network example. @@ -29,9 +29,11 @@ Docs: https://tensorflownet.readthedocs.io 7. Add BatchMatMulGrad. 8. Upgrade NumSharp. 9. Fix strided_slice_grad type convention error. -10. Add AbsGrad. +10. Add AbsGrad. +11. Fix Session.LoadFromSavedModel(string). +12. Add Tensor operator overloads. 7.2 - 0.10.3.0 + 0.10.4.0 LICENSE true true diff --git a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs index 1d7808b7..a627c517 100644 --- a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs @@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); print($"Accuracy: {acc.ToString("F4")}"); - return acc > 0.9; + return acc > 0.88; }); } From 29265c55ee965a23ff7d04ce5f9d9d9885ae7f48 Mon Sep 17 00:00:00 2001 From: Sattisvar TANDABANY Date: Thu, 1 Aug 2019 20:59:52 +0200 Subject: [PATCH 08/11] fix memory leak due to wrong handle sent to api (#339) A simple typo that led to a memory leak. --- src/TensorFlowNET.Core/Tensors/Tensor.cs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index aebca212..54b58122 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -363,7 +363,7 @@ namespace Tensorflow _handle=IntPtr.Zero; } if (h != IntPtr.Zero) - c_api.TF_DeleteTensor(_handle); + c_api.TF_DeleteTensor(h); status.Dispose(); GC.SuppressFinalize(this); } From a8aac8d0b8a77850f9d58518a5f6a9352b002d77 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 1 Aug 2019 09:18:05 -0500 Subject: [PATCH 09/11] fix default graph and operation issue when import model. --- src/TensorFlowNET.Core/Buffers/Buffer.cs | 4 +- .../Framework/c_api_util.cs | 2 +- .../Graphs/Graph.Operation.cs | 25 +++++++ src/TensorFlowNET.Core/Graphs/Graph.cs | 71 +++++++++++-------- src/TensorFlowNET.Core/Graphs/c_api.graph.cs | 4 ++ .../Operations/Operation.Implicit.cs | 5 +- .../Operations/Operation.Output.cs | 2 + src/TensorFlowNET.Core/Sessions/Session.cs | 18 ++--- .../BasicModels/LogisticRegression.cs | 2 +- 9 files changed, 89 insertions(+), 44 deletions(-) diff --git a/src/TensorFlowNET.Core/Buffers/Buffer.cs b/src/TensorFlowNET.Core/Buffers/Buffer.cs index 0b73265d..378c7c85 100644 --- a/src/TensorFlowNET.Core/Buffers/Buffer.cs +++ b/src/TensorFlowNET.Core/Buffers/Buffer.cs @@ -30,8 +30,8 @@ namespace Tensorflow get { var data = new byte[buffer.length]; - if (buffer.length > 0) - Marshal.Copy(buffer.data, data, 0, (int)buffer.length); + if (data.Length > 0) + Marshal.Copy(buffer.data, data, 0, data.Length); return data; } } diff --git a/src/TensorFlowNET.Core/Framework/c_api_util.cs b/src/TensorFlowNET.Core/Framework/c_api_util.cs index 440cbf44..5d5cb9b3 100644 --- a/src/TensorFlowNET.Core/Framework/c_api_util.cs +++ b/src/TensorFlowNET.Core/Framework/c_api_util.cs @@ -128,7 +128,7 @@ namespace Tensorflow IntPtr c_op; while ((c_op = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero) { - yield return c_op; + yield return new Operation(c_op, graph); } } } diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs index 06b65f03..09e09573 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Operation.cs @@ -38,6 +38,31 @@ namespace Tensorflow return c_api.TF_NewOperation(_handle, opType, opName); } + public unsafe Operation[] ReturnOperations(IntPtr results) + { + TF_Operation return_oper_handle = new TF_Operation(); + int num_return_opers = 0; + c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); + Operation[] return_opers = new Operation[num_return_opers]; + for (int i = 0; i < num_return_opers; i++) + { + var handle = return_oper_handle.node + Marshal.SizeOf() * i; + return_opers[i] = new Operation(*(IntPtr*)handle); + } + + return return_opers; + } + + public Operation OperationByName(string operName) + { + return c_api.TF_GraphOperationByName(_handle, operName); + } + + public ITensorOrOperation[] get_operations() + { + return _nodes_by_name.Values.Select(x => x).ToArray(); + } + /// /// Returns the `Operation` with the given `name`. /// diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 82e83df1..08ed95af 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; @@ -72,7 +73,7 @@ namespace Tensorflow all variables that are created during the construction of a graph. The caller may define additional collections by specifying a new name. */ - public partial class Graph : IPython, IDisposable + public partial class Graph : IPython, IDisposable, IEnumerable { private IntPtr _handle; private Dictionary _nodes_by_id; @@ -121,6 +122,10 @@ namespace Tensorflow _nodes_by_name = new Dictionary(); _names_in_use = new Dictionary(); _graph_key = $"grap-key-{ops.uid()}/"; + } + + public void __enter__() + { } public ITensorOrOperation as_graph_element(object obj, bool allow_tensor = true, bool allow_operation = true) @@ -409,31 +414,6 @@ namespace Tensorflow return return_outputs; } - public unsafe Operation[] ReturnOperations(IntPtr results) - { - TF_Operation return_oper_handle = new TF_Operation(); - int num_return_opers = 0; - c_api.TF_ImportGraphDefResultsReturnOperations(results, ref num_return_opers, ref return_oper_handle); - Operation[] return_opers = new Operation[num_return_opers]; - for (int i = 0; i < num_return_opers; i++) - { - var handle = return_oper_handle.node + Marshal.SizeOf() * i; - return_opers[i] = new Operation(*(IntPtr*)handle); - } - - return return_opers; - } - - public Operation OperationByName(string operName) - { - return c_api.TF_GraphOperationByName(_handle, operName); - } - - public ITensorOrOperation[] get_operations() - { - return _nodes_by_name.Values.Select(x => x).ToArray(); - } - public string[] get_all_collection_keys() { return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); @@ -481,17 +461,46 @@ namespace Tensorflow public Tensor get_tensor_by_name(string name) { return (Tensor)this.as_graph_element(name, allow_tensor: true, allow_operation: false); - } - - public void __enter__() - { + } + + public TensorShape GetTensorShape(TF_Output output) + { + var status = new Status(); + var ndim = c_api.TF_GraphGetTensorNumDims(_handle, output, status); + status.Check(); + + if (ndim == -1) + return new TensorShape(); + + var dims = new long[ndim]; + c_api.TF_GraphGetTensorShape(_handle, output, dims, dims.Length, status); + status.Check(); + + return new TensorShape(dims.Select(x => (int)x).ToArray()); + } + + public override string ToString() + { + int len = 0; + return c_api.TF_GraphDebugString(_handle, out len); } public void __exit__() { - } + } + + private IEnumerable GetEnumerable() + => c_api_util.tf_operations(this); + IEnumerator IEnumerable.GetEnumerator() + => GetEnumerable().GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() + { + throw new NotImplementedException(); + } + public static implicit operator IntPtr(Graph graph) { return graph._handle; diff --git a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs index 05cd5940..889949ef 100644 --- a/src/TensorFlowNET.Core/Graphs/c_api.graph.cs +++ b/src/TensorFlowNET.Core/Graphs/c_api.graph.cs @@ -43,6 +43,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern void TF_DeleteImportGraphDefResults(IntPtr results); + [DllImport(TensorFlowLibName)] + public static extern string TF_GraphDebugString(IntPtr graph, out int len); + [DllImport(TensorFlowLibName)] public static extern void TF_GraphGetOpDef(IntPtr graph, string op_name, IntPtr output_op_def, IntPtr status); @@ -100,6 +103,7 @@ namespace Tensorflow /// TF_Status* [DllImport(TensorFlowLibName)] public static extern void TF_GraphImportGraphDef(IntPtr graph, IntPtr graph_def, IntPtr options, IntPtr status); + /// /// Iterate through the operations of a graph. /// diff --git a/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs index 1b99dcc8..8de412c8 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs @@ -23,7 +23,10 @@ namespace Tensorflow /// public partial class Operation { - public static implicit operator Operation(IntPtr handle) => new Operation(handle); + // make sure the new op is in the same graph instance + public static implicit operator Operation(IntPtr handle) + => new Operation(handle); + public static implicit operator IntPtr(Operation op) => op._handle; public static implicit operator Tensor(Operation op) => op.output; diff --git a/src/TensorFlowNET.Core/Operations/Operation.Output.cs b/src/TensorFlowNET.Core/Operations/Operation.Output.cs index cefb76cf..41f4a332 100644 --- a/src/TensorFlowNET.Core/Operations/Operation.Output.cs +++ b/src/TensorFlowNET.Core/Operations/Operation.Output.cs @@ -35,6 +35,8 @@ namespace Tensorflow public int OutputNumConsumers(int index) => c_api.TF_OperationOutputNumConsumers(new TF_Output(_handle, index)); + public TF_Output this[int index] => _tf_output(index); + public unsafe TF_Input[] OutputConsumers(int index, int max_consumers) { int size = Marshal.SizeOf(); diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index 374a57ad..191730d0 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -15,6 +15,7 @@ ******************************************************************************/ using System; +using System.Runtime.InteropServices; namespace Tensorflow { @@ -26,8 +27,8 @@ namespace Tensorflow } - public Session(IntPtr handle) - : base("", null, null) + public Session(IntPtr handle, Graph g = null) + : base("", g, null) { _session = handle; } @@ -50,8 +51,10 @@ namespace Tensorflow var graph = c_api.TF_NewGraph(); var status = new Status(); var opt = c_api.TF_NewSessionOptions(); + var tags = new string[] { "serve" }; var buffer = new TF_Buffer(); + var sess = c_api.TF_LoadSessionFromSavedModel(opt, IntPtr.Zero, path, @@ -61,14 +64,13 @@ namespace Tensorflow ref buffer, status); - //var bytes = new Buffer(buffer.data).Data; - //var meta_graph = MetaGraphDef.Parser.ParseFrom(bytes); - + // load graph bytes + // var data = new byte[buffer.length]; + // Marshal.Copy(buffer.data, data, 0, (int)buffer.length); + // var meta_graph = MetaGraphDef.Parser.ParseFrom(data);*/ status.Check(); - new Graph(graph).as_default(); - - return sess; + return new Session(sess, g: new Graph(graph).as_default()); } public static implicit operator IntPtr(Session session) => session._session; diff --git a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs index a627c517..1d7808b7 100644 --- a/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs +++ b/test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs @@ -118,7 +118,7 @@ namespace TensorFlowNET.Examples float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); print($"Accuracy: {acc.ToString("F4")}"); - return acc > 0.88; + return acc > 0.9; }); } From 33d82e8817b31950e054c17dfb425899a2fd5bd2 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 1 Aug 2019 13:46:08 -0500 Subject: [PATCH 10/11] GC.SuppressFinalize for TF object. --- src/TensorFlowNET.Core/Graphs/Graph.Export.cs | 5 +- src/TensorFlowNET.Core/Graphs/Graph.Import.cs | 14 +- src/TensorFlowNET.Core/Graphs/Graph.cs | 12 +- .../Sessions/BaseSession.cs | 5 +- src/TensorFlowNET.Core/Sessions/Session.cs | 17 +- .../TensorFlowNET.Core.csproj | 11 +- .../Tensors/Tensor.Creation.cs | 1355 +++++++++-------- src/TensorFlowNET.Core/Tensors/Tensor.cs | 44 +- .../Tensors/c_api.tensor.cs | 6 + 9 files changed, 801 insertions(+), 668 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs index 67b93191..60657038 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Export.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Export.cs @@ -31,8 +31,9 @@ namespace Tensorflow private GraphDef _as_graph_def(bool add_shapes = false) { - var buffer = ToGraphDef(Status); - Status.Check(); + var status = new Status(); + var buffer = ToGraphDef(status); + status.Check(); var def = GraphDef.Parser.ParseFrom(buffer); buffer.Dispose(); diff --git a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs index 7fcfdbd7..af7ebfd1 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.Import.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.Import.cs @@ -43,16 +43,20 @@ namespace Tensorflow var bytes = File.ReadAllBytes(file_path); var graph_def = new Tensorflow.Buffer(bytes); var opts = c_api.TF_NewImportGraphDefOptions(); - c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, Status); - return Status; + var status = new Status(); + c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status); + return status; } - public Status Import(byte[] bytes) + public Status Import(byte[] bytes, string prefix = "") { var graph_def = new Tensorflow.Buffer(bytes); var opts = c_api.TF_NewImportGraphDefOptions(); - c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, Status); - return Status; + c_api.TF_ImportGraphDefOptionsSetPrefix(opts, prefix); + var status = new Status(); + c_api.TF_GraphImportGraphDef(_handle, graph_def, opts, status); + c_api.TF_DeleteImportGraphDefOptions(opts); + return status; } public static Graph ImportFromPB(string file_path, string name = null) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 08ed95af..2acc0bb7 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -88,8 +88,7 @@ namespace Tensorflow private string _graph_key; public string graph_key => _graph_key; public string _last_loss_reduction; - public bool _is_loss_scaled_by_optimizer { get; set; } - public Status Status { get; } + public bool _is_loss_scaled_by_optimizer { get; set; } /// /// True if the graph is considered "finalized". In that case no @@ -107,7 +106,6 @@ namespace Tensorflow public Graph() { _handle = c_api.TF_NewGraph(); - Status = new Status(); _nodes_by_id = new Dictionary(); _nodes_by_name = new Dictionary(); _names_in_use = new Dictionary(); @@ -117,7 +115,6 @@ namespace Tensorflow public Graph(IntPtr handle) { _handle = handle; - Status = new Status(); _nodes_by_id = new Dictionary(); _nodes_by_name = new Dictionary(); _names_in_use = new Dictionary(); @@ -448,7 +445,12 @@ namespace Tensorflow public void Dispose() { - // c_api.TF_DeleteGraph(_handle); + if (_handle != IntPtr.Zero) + c_api.TF_DeleteGraph(_handle); + + _handle = IntPtr.Zero; + + GC.SuppressFinalize(this); } /// diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs index deb82b51..47a891d6 100644 --- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs +++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs @@ -32,20 +32,19 @@ namespace Tensorflow protected int _current_version; protected byte[] _target; protected IntPtr _session; - public Status Status; public Graph graph => _graph; public BaseSession(string target = "", Graph g = null, SessionOptions opts = null) { _graph = g is null ? ops.get_default_graph() : g; - + _graph.as_default(); _target = UTF8Encoding.UTF8.GetBytes(target); SessionOptions newOpts = null; if (opts == null) newOpts = c_api.TF_NewSessionOptions(); - Status = new Status(); + var Status = new Status(); _session = c_api.TF_NewSession(_graph, opts ?? newOpts, Status); diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index 191730d0..3e7dca84 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -37,7 +37,7 @@ namespace Tensorflow : base("", g, opts) { if (s == null) - s = Status; + s = new Status(); } public Session as_default() @@ -83,8 +83,19 @@ namespace Tensorflow public void Dispose() { - c_api.TF_DeleteSession(_session, Status); - Status.Dispose(); + if (_session != IntPtr.Zero) + { + var status = new Status(); + c_api.TF_DeleteSession(_session, status); + } + + _session = IntPtr.Zero; + GC.SuppressFinalize(this); + } + + ~Session() + { + Dispose(); } public void __enter__() diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 7374f82f..36ff24ac 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.10.4 + 0.10.7 Haiping Chen, Meinrad Recheis SciSharp STACK true @@ -17,7 +17,7 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.10.4.0 + 0.10.7.0 Changes since v0.9.0: 1. Added full connected Convolution Neural Network example. @@ -31,9 +31,12 @@ Docs: https://tensorflownet.readthedocs.io 9. Fix strided_slice_grad type convention error. 10. Add AbsGrad. 11. Fix Session.LoadFromSavedModel(string). -12. Add Tensor operator overloads. +12. Add Tensor operator overloads. +13. Fix default graph and operation issue when import model. +14. Fix TF_String endcode and decode. +15. Fix Tensor memory leak. 7.2 - 0.10.4.0 + 0.10.7.0 LICENSE true true diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs index f5ac5f77..a104f066 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs @@ -1,648 +1,717 @@ -/***************************************************************************** - Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -******************************************************************************/ - -using NumSharp; -using System; -using System.Linq; -using System.Numerics; -using System.Runtime.CompilerServices; -using System.Runtime.InteropServices; -using System.Text; -using static Tensorflow.c_api; - -namespace Tensorflow -{ - public partial class Tensor - { - /// - /// true if unmanaged buffer has been freed. - /// - private bool _deallocator_called => _deallocatorArgs.deallocator_called; - - /// - /// true if the Tensor was created from a managed array - /// - private bool _isPinnedArray => _deallocatorArgs.gc_handle != IntPtr.Zero; - - /// - /// True only if the Tensor object was created in a way that the Tensor object itself allocated memory or pinned a managed object. - /// False if the Tensor was created from a pointer - /// - public bool IsMemoryOwner { get; private set; } - - /// - /// This holds values that are used by the unmanaged deallocator callback - /// - private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero }; - - // note: they must be assigned to a static variable in order to work as unmanaged callbacks - static Deallocator _hGlobalDeallocator = FreeHGlobalMemory; - static Deallocator _gcHandleDeallocator = FreeGCHandle; - private static Deallocator _nothingDeallocator = FreeNothing; - - /// - /// Create a Tensor object from an existing TF handle - /// - /// - public Tensor(IntPtr handle) - { - _handle = handle; - IsMemoryOwner = false; - } - - /// - /// Create a new Tensor from the given unmanaged memory pointer (which must be allocated, fixed or pinned by the caller) - /// Note: the caller is responsible for freeing the memory. Calling Dispose on this object will dispose the TensorFlow tensor - /// but not the memory itself! - /// - /// Pointer to unmanaged, fixed or pinned memory which the caller owns - /// Tensor shape - /// TF data type - /// Size of the tensor in memory - public Tensor(IntPtr ptr, long[] shape, TF_DataType dType, int num_bytes) - { - _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); - IsMemoryOwner = false; - } - -#if _REGEN - %types=["sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] - %foreach types% - - /// - /// Create a 1d Tensor from the given linear array and shape - /// - public Tensor(#1[] data, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(#1)), new long[]{data.Length}, data, Marshal.SizeOf<#1>()); - IsMemoryOwner=true; - } - - /// - /// Create a N-dimensional Tensor from the given array - /// - public Tensor(#1[] data, long[] shape, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(#1)), shape, data, Marshal.SizeOf<#1>()); - IsMemoryOwner=true; - } - - /// - /// Create a scalar Tensor from the given value - /// - public unsafe Tensor(#1 value, TF_DataType? dType = null) - { - var v = (#1*)Marshal.AllocHGlobal(sizeof(#1)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(#1)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(#1), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; - } - % -#else - - - /// - /// Create a 1d Tensor from the given linear array and shape - /// - public Tensor(sbyte[] data, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(sbyte)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a N-dimensional Tensor from the given array - /// - public Tensor(sbyte[] data, long[] shape, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(sbyte)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a scalar Tensor from the given value - /// - public unsafe Tensor(sbyte value, TF_DataType? dType = null) - { - var v = (sbyte*)Marshal.AllocHGlobal(sizeof(sbyte)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(sbyte)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(sbyte), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; - } - - /// - /// Create a 1d Tensor from the given linear array and shape - /// - public Tensor(byte[] data, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(byte)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a N-dimensional Tensor from the given array - /// - public Tensor(byte[] data, long[] shape, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(byte)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a scalar Tensor from the given value - /// - public unsafe Tensor(byte value, TF_DataType? dType = null) - { - var v = (byte*)Marshal.AllocHGlobal(sizeof(byte)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(byte)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(byte), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; - } - - /// - /// Create a 1d Tensor from the given linear array and shape - /// - public Tensor(short[] data, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(short)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a N-dimensional Tensor from the given array - /// - public Tensor(short[] data, long[] shape, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(short)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a scalar Tensor from the given value - /// - public unsafe Tensor(short value, TF_DataType? dType = null) - { - var v = (short*)Marshal.AllocHGlobal(sizeof(short)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(short)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(short), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; - } - - /// - /// Create a 1d Tensor from the given linear array and shape - /// - public Tensor(ushort[] data, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ushort)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a N-dimensional Tensor from the given array - /// - public Tensor(ushort[] data, long[] shape, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ushort)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a scalar Tensor from the given value - /// - public unsafe Tensor(ushort value, TF_DataType? dType = null) - { - var v = (ushort*)Marshal.AllocHGlobal(sizeof(ushort)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(ushort)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(ushort), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; - } - - /// - /// Create a 1d Tensor from the given linear array and shape - /// - public Tensor(int[] data, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(int)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a N-dimensional Tensor from the given array - /// - public Tensor(int[] data, long[] shape, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(int)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a scalar Tensor from the given value - /// - public unsafe Tensor(int value, TF_DataType? dType = null) - { - var v = (int*)Marshal.AllocHGlobal(sizeof(int)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(int)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(int), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; - } - - /// - /// Create a 1d Tensor from the given linear array and shape - /// - public Tensor(uint[] data, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(uint)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a N-dimensional Tensor from the given array - /// - public Tensor(uint[] data, long[] shape, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(uint)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a scalar Tensor from the given value - /// - public unsafe Tensor(uint value, TF_DataType? dType = null) - { - var v = (uint*)Marshal.AllocHGlobal(sizeof(uint)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(uint)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(uint), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; - } - - /// - /// Create a 1d Tensor from the given linear array and shape - /// - public Tensor(long[] data, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(long)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a N-dimensional Tensor from the given array - /// - public Tensor(long[] data, long[] shape, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(long)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a scalar Tensor from the given value - /// - public unsafe Tensor(long value, TF_DataType? dType = null) - { - var v = (long*)Marshal.AllocHGlobal(sizeof(long)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(long)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(long), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; - } - - /// - /// Create a 1d Tensor from the given linear array and shape - /// - public Tensor(ulong[] data, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ulong)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a N-dimensional Tensor from the given array - /// - public Tensor(ulong[] data, long[] shape, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ulong)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a scalar Tensor from the given value - /// - public unsafe Tensor(ulong value, TF_DataType? dType = null) - { - var v = (ulong*)Marshal.AllocHGlobal(sizeof(ulong)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(ulong)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(ulong), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; - } - - /// - /// Create a 1d Tensor from the given linear array and shape - /// - public Tensor(float[] data, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(float)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a N-dimensional Tensor from the given array - /// - public Tensor(float[] data, long[] shape, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(float)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a scalar Tensor from the given value - /// - public unsafe Tensor(float value, TF_DataType? dType = null) - { - var v = (float*)Marshal.AllocHGlobal(sizeof(float)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(float)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(float), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; - } - - /// - /// Create a 1d Tensor from the given linear array and shape - /// - public Tensor(double[] data, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(double)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a N-dimensional Tensor from the given array - /// - public Tensor(double[] data, long[] shape, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(double)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a scalar Tensor from the given value - /// - public unsafe Tensor(double value, TF_DataType? dType = null) - { - var v = (double*)Marshal.AllocHGlobal(sizeof(double)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(double)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; - } - - /// - /// Create a 1d Tensor from the given linear array and shape - /// - public Tensor(Complex[] data, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(Complex)), new long[]{data.Length}, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a N-dimensional Tensor from the given array - /// - public Tensor(Complex[] data, long[] shape, TF_DataType? dType = null) - { - _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(Complex)), shape, data, Marshal.SizeOf()); - IsMemoryOwner=true; - } - - /// - /// Create a scalar Tensor from the given value - /// - public unsafe Tensor(Complex value, TF_DataType? dType = null) - { - var v = (Complex*)Marshal.AllocHGlobal(sizeof(Complex)); - *v = value; - _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); - IsMemoryOwner=true; - } -#endif - - /// - /// Create a string Tensor from the given string - /// - public unsafe Tensor(string str) - { - var buffer = Encoding.UTF8.GetBytes(str); - var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); - - IntPtr tensor = c_api.TF_TensorData(handle); - Marshal.WriteInt64(tensor, 0); - fixed (byte* src = &buffer[0]) - c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); - _handle = handle; - status.Check(true); - } - - public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) - { - if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") - { - var buffer = nd.Data(); - var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); - - IntPtr tensor = c_api.TF_TensorData(handle); - Marshal.WriteInt64(tensor, 0); - fixed (byte* src = &buffer[0]) - c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); - - status.Check(true); - _handle=handle; - IsMemoryOwner = false; - return; - } - _handle = Allocate(nd, tensorDType: tensorDType); - IsMemoryOwner = true; - } - - private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null) - { - IntPtr dotHandle = IntPtr.Zero; - int buffersize = 0; - - if (nd.dtype.Name != "String") +/***************************************************************************** + Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +******************************************************************************/ + +using NumSharp; +using System; +using System.Linq; +using System.Numerics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Text; +using static Tensorflow.c_api; + +namespace Tensorflow +{ + public partial class Tensor + { + /// + /// true if unmanaged buffer has been freed. + /// + private bool _deallocator_called => _deallocatorArgs.deallocator_called; + + /// + /// true if the Tensor was created from a managed array + /// + private bool _isPinnedArray => _deallocatorArgs.gc_handle != IntPtr.Zero; + + /// + /// True only if the Tensor object was created in a way that the Tensor object itself allocated memory or pinned a managed object. + /// False if the Tensor was created from a pointer + /// + public bool IsMemoryOwner { get; private set; } + + /// + /// This holds values that are used by the unmanaged deallocator callback + /// + private DeallocatorArgs _deallocatorArgs = new DeallocatorArgs() { gc_handle = IntPtr.Zero }; + + // note: they must be assigned to a static variable in order to work as unmanaged callbacks + static Deallocator _hGlobalDeallocator = FreeHGlobalMemory; + static Deallocator _gcHandleDeallocator = FreeGCHandle; + private static Deallocator _nothingDeallocator = FreeNothing; + + /// + /// Create a Tensor object from an existing TF handle + /// + /// + public Tensor(IntPtr handle) + { + _handle = handle; + IsMemoryOwner = false; + } + + /// + /// Create a new Tensor from the given unmanaged memory pointer (which must be allocated, fixed or pinned by the caller) + /// Note: the caller is responsible for freeing the memory. Calling Dispose on this object will dispose the TensorFlow tensor + /// but not the memory itself! + /// + /// Pointer to unmanaged, fixed or pinned memory which the caller owns + /// Tensor shape + /// TF data type + /// Size of the tensor in memory + public Tensor(IntPtr ptr, long[] shape, TF_DataType dType, int num_bytes) + { + _handle = TF_NewTensor(dType, dims: shape, num_dims: shape.Length, data: ptr, len: (UIntPtr)num_bytes, deallocator: _nothingDeallocator, ref _deallocatorArgs); + IsMemoryOwner = false; + } + +#if _REGEN + %types=["sbyte", "bool", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"] + %foreach types% + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(#1[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(#1)), new long[]{data.Length}, data, Marshal.SizeOf<#1>()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(#1[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(#1)), shape, data, Marshal.SizeOf<#1>()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(#1 value, TF_DataType? dType = null) + { + var v = (#1*)Marshal.AllocHGlobal(sizeof(#1)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(#1)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(#1), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } + % +#else + + + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(sbyte[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(sbyte)), new long[]{data.Length}, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(sbyte[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(sbyte)), shape, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(sbyte value, TF_DataType? dType = null) + { + var v = (sbyte*)Marshal.AllocHGlobal(sizeof(sbyte)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(sbyte)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(sbyte), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(bool[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(bool)), new long[]{data.Length}, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(bool[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(bool)), shape, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(bool value, TF_DataType? dType = null) + { + var v = (bool*)Marshal.AllocHGlobal(sizeof(bool)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(bool)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(bool), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(byte[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(byte)), new long[]{data.Length}, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(byte[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(byte)), shape, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(byte value, TF_DataType? dType = null) + { + var v = (byte*)Marshal.AllocHGlobal(sizeof(byte)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(byte)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(byte), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(short[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(short)), new long[]{data.Length}, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(short[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(short)), shape, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(short value, TF_DataType? dType = null) + { + var v = (short*)Marshal.AllocHGlobal(sizeof(short)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(short)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(short), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(ushort[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ushort)), new long[]{data.Length}, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(ushort[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ushort)), shape, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(ushort value, TF_DataType? dType = null) + { + var v = (ushort*)Marshal.AllocHGlobal(sizeof(ushort)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(ushort)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(ushort), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(int[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(int)), new long[]{data.Length}, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(int[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(int)), shape, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(int value, TF_DataType? dType = null) + { + var v = (int*)Marshal.AllocHGlobal(sizeof(int)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(int)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(int), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(uint[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(uint)), new long[]{data.Length}, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(uint[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(uint)), shape, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(uint value, TF_DataType? dType = null) + { + var v = (uint*)Marshal.AllocHGlobal(sizeof(uint)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(uint)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(uint), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(long[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(long)), new long[]{data.Length}, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(long[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(long)), shape, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(long value, TF_DataType? dType = null) + { + var v = (long*)Marshal.AllocHGlobal(sizeof(long)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(long)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(long), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(ulong[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ulong)), new long[]{data.Length}, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(ulong[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(ulong)), shape, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(ulong value, TF_DataType? dType = null) + { + var v = (ulong*)Marshal.AllocHGlobal(sizeof(ulong)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(ulong)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(ulong), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(float[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(float)), new long[]{data.Length}, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(float[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(float)), shape, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(float value, TF_DataType? dType = null) + { + var v = (float*)Marshal.AllocHGlobal(sizeof(float)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(float)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(float), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(double[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(double)), new long[]{data.Length}, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(double[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(double)), shape, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(double value, TF_DataType? dType = null) + { + var v = (double*)Marshal.AllocHGlobal(sizeof(double)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(double)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } + + /// + /// Create a 1d Tensor from the given linear array and shape + /// + public Tensor(Complex[] data, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(Complex)), new long[]{data.Length}, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a N-dimensional Tensor from the given array + /// + public Tensor(Complex[] data, long[] shape, TF_DataType? dType = null) + { + _handle = CreateTensorWithoutCopying(dType ?? dtypes.as_dtype(typeof(Complex)), shape, data, Marshal.SizeOf()); + IsMemoryOwner=true; + } + + /// + /// Create a scalar Tensor from the given value + /// + public unsafe Tensor(Complex value, TF_DataType? dType = null) + { + var v = (Complex*)Marshal.AllocHGlobal(sizeof(Complex)); + *v = value; + _handle = TF_NewTensor(dType ?? dtypes.as_dtype(typeof(Complex)), dims:new long[0], num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: _hGlobalDeallocator, ref _deallocatorArgs); + IsMemoryOwner=true; + } +#endif + + /// + /// Create a string Tensor from the given string + /// + public unsafe Tensor(string str) + { + var status = new Status(); + var buffer = Encoding.UTF8.GetBytes(str); + var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); + + IntPtr tensor = c_api.TF_TensorData(handle); + Marshal.WriteInt64(tensor, 0); + fixed (byte* src = &buffer[0]) + c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); + _handle = handle; + status.Check(true); + } + + public unsafe Tensor(NDArray nd, TF_DataType? tensorDType = null) + { + if (tensorDType == TF_DataType.TF_STRING && nd.dtype.Name == "Byte") + { + var buffer = nd.Data(); + var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); + + IntPtr tensor = c_api.TF_TensorData(handle); + Marshal.WriteInt64(tensor, 0); + + var status = new Status(); + fixed (byte* src = &buffer[0]) + c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); + + status.Check(true); + _handle=handle; + IsMemoryOwner = false; + return; + } + _handle = Allocate(nd, tensorDType: tensorDType); + IsMemoryOwner = true; + } + + private unsafe IntPtr Allocate(NDArray nd, TF_DataType? tensorDType = null) + { + IntPtr dotHandle = IntPtr.Zero; + int buffersize = 0; + + if (nd.dtype.Name != "String") + { + buffersize = (nd.size * nd.dtypesize); + dotHandle = Marshal.AllocHGlobal(buffersize); + } + + var dataType = ToTFDataType(nd.dtype); + // shape + var dims = nd.shape.Select(x => (long)x).ToArray(); + var nd1 = nd.ravel(); + switch (nd.dtype.Name) + { + case "Boolean": + var boolVals = Array.ConvertAll(nd1.Data(), x => Convert.ToByte(x)); + Marshal.Copy(boolVals, 0, dotHandle, nd.size); + break; + case "Int16": + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + break; + case "Int32": + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + break; + case "Int64": + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + break; + case "Single": + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + break; + case "Double": + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + break; + case "Byte": + Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); + break; + case "String": + return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data(0)), TF_DataType.TF_STRING); + default: + throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); + } + var tfHandle = c_api.TF_NewTensor(dataType, + dims, + dims.Length, + dotHandle, + (UIntPtr)buffersize, + _hGlobalDeallocator, + ref _deallocatorArgs); + + return tfHandle; + } + + public unsafe Tensor(byte[][] buffer, long[] shape) + { + int size = 0; + foreach (var b in buffer) { - buffersize = (nd.size * nd.dtypesize); - dotHandle = Marshal.AllocHGlobal(buffersize); + size += (int)TF_StringEncodedSize((UIntPtr)b.Length); } - - var dataType = ToTFDataType(nd.dtype); - // shape - var dims = nd.shape.Select(x => (long)x).ToArray(); - var nd1 = nd.ravel(); - switch (nd.dtype.Name) + int totalSize = size + buffer.Length * 8; + ulong offset = 0; + IntPtr handle = TF_AllocateTensor(TF_DataType.TF_STRING, shape, shape.Length, (UIntPtr)totalSize); + + // Clear offset table + IntPtr pOffset = TF_TensorData(handle); + IntPtr dst = pOffset + buffer.Length * 8; + IntPtr dstLimit = pOffset + totalSize; + for (int i = 0; i < buffer.Length; i++) { - case "Boolean": - var boolVals = Array.ConvertAll(nd1.Data(), x => Convert.ToByte(x)); - Marshal.Copy(boolVals, 0, dotHandle, nd.size); - break; - case "Int16": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); - break; - case "Int32": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); - break; - case "Int64": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); - break; - case "Single": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); - break; - case "Double": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); - break; - case "Byte": - Marshal.Copy(nd1.Data(), 0, dotHandle, nd.size); - break; - case "String": - return new Tensor(UTF8Encoding.UTF8.GetBytes(nd.Data(0)), TF_DataType.TF_STRING); - default: - throw new NotImplementedException($"Marshal.Copy failed for {nd.dtype.Name}."); + Marshal.WriteInt64(pOffset, (long)offset); + using (var status = new Status()) + { + fixed (byte* src = &buffer[i][0]) + { + var written = TF_StringEncode(src, (UIntPtr)buffer[i].Length, (sbyte*)dst, (UIntPtr)(dstLimit.ToInt64() - dst.ToInt64()), status); + status.Check(true); + pOffset += 8; + dst += (int)written; + offset += written; + } + } } - var tfHandle = c_api.TF_NewTensor(dataType, - dims, - dims.Length, - dotHandle, - (UIntPtr)buffersize, - _hGlobalDeallocator, - ref _deallocatorArgs); - - return tfHandle; - } - - public Tensor(Operation op, int value_index, TF_DataType dtype) - { - _op = op; - _value_index = value_index; - _dtype = dtype; - _id = ops.uid(); - } - - - /// - /// Creates a new tensor from the given array without copying memory. The array is pinned down and the pointer passed on. - /// - /// Represents the tensor shape. - /// The linear array of data, the data must fit in the tensor with the specified dimensions. - /// The number of bytes in memory of a single array element - /// - /// Use the FromBuffer method to create a tensor that has the specified dimensions - /// and is initialized with data from the data array. The data is copied starting - /// at the start offset, for count bytes and is laid out into the tensor following the - /// specified dimensions. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) - { - if (dt == TF_DataType.TF_STRING && data is byte[]) - { - var buffer = (byte[])data; - var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); - var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); - IntPtr tensor = c_api.TF_TensorData(handle); - Marshal.WriteInt64(tensor, 0); - fixed (byte* src = &buffer[0]) - c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); - - status.Check(true); - return handle; - } - return CreateTensorWithoutCopying(dt, shape, data, 0, data.Length, element_size); - } - - /// - /// Creates a new tensor from a subsection of the given array without copying memory. The array is pinned down and the pointer passed on. - /// - /// Represents the tensor shape. - /// The linear array of data, the data must fit in the tensor with the specified dimensions. - /// The offset into the provided data array where the data resides. - /// The number of elements to copy from data. - /// The number of bytes in memory of a single array element - /// - /// Use the FromBuffer method to create a tensor that has the specified dimensions - /// and is initialized with data from the data array. The data is copied starting - /// at the start offset, for count bytes and is laid out into the tensor following the - /// specified dimensions. - /// - [MethodImpl(MethodImplOptions.AggressiveInlining)] - protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int start, int count, int element_size) - { - if (start < 0 || start > data.Length - count) - throw new ArgumentException($"Array length {data.Length} does not match the given shape {new Shape(shape.Cast().ToArray())}"); - - // get a handle to the pinned array which we will pass on to the tensor computation engine to use - var gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned); - _deallocatorArgs = new DeallocatorArgs() { gc_handle = GCHandle.ToIntPtr(gcHandle) }; - if (shape == null || shape.Length == 0) - return TF_NewTensor(dt, new long[0], 0, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr)(count * element_size), _gcHandleDeallocator, ref _deallocatorArgs); - else - return TF_NewTensor(dt, shape, shape.Length, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr)(count * element_size), _gcHandleDeallocator, ref _deallocatorArgs); - } - - [MonoPInvokeCallback(typeof(Deallocator))] - internal static void FreeHGlobalMemory(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args) - { - if (args.deallocator_called) - return; - Marshal.FreeHGlobal(dataPtr); - args.deallocator_called = true; - } - - [MonoPInvokeCallback(typeof(Deallocator))] - internal static void FreeGCHandle(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args) - { - if (args.deallocator_called || args.gc_handle == IntPtr.Zero) - return; - // note: since the ptr given to tensorflow is just the addr of the pinned object we can not directly free it! we need to free the gcHandle instead - GCHandle.FromIntPtr(args.gc_handle).Free(); - args.deallocator_called = true; - } - - [MonoPInvokeCallback(typeof(Deallocator))] - internal static void FreeNothing(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args) - { - args.deallocator_called = true; - } - - } - - /// - /// This attribute can be applied to callback functions that will be invoked - /// from unmanaged code to managed code. - /// - /// - /// - /// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))] - /// internal static void MyFreeFunc (IntPtr data, IntPtr length){..} - /// - /// - public sealed class MonoPInvokeCallbackAttribute : Attribute - { - /// - /// Use this constructor to annotate the type of the callback function that - /// will be invoked from unmanaged code. - /// - /// T. - public MonoPInvokeCallbackAttribute(Type t) { } - } - -} + _handle = handle; + } + + public Tensor(Operation op, int value_index, TF_DataType dtype) + { + _op = op; + _value_index = value_index; + _dtype = dtype; + _id = ops.uid(); + } + + + /// + /// Creates a new tensor from the given array without copying memory. The array is pinned down and the pointer passed on. + /// + /// Represents the tensor shape. + /// The linear array of data, the data must fit in the tensor with the specified dimensions. + /// The number of bytes in memory of a single array element + /// + /// Use the FromBuffer method to create a tensor that has the specified dimensions + /// and is initialized with data from the data array. The data is copied starting + /// at the start offset, for count bytes and is laid out into the tensor following the + /// specified dimensions. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int element_size) + { + if (dt == TF_DataType.TF_STRING && data is byte[]) + { + var buffer = (byte[])data; + var size = c_api.TF_StringEncodedSize((UIntPtr)buffer.Length); + var handle = TF_AllocateTensor(TF_DataType.TF_STRING, IntPtr.Zero, 0, (UIntPtr)((ulong)size + 8)); + + IntPtr tensor = c_api.TF_TensorData(handle); + Marshal.WriteInt64(tensor, 0); + + var status = new Status(); + fixed (byte* src = &buffer[0]) + c_api.TF_StringEncode(src, (UIntPtr)buffer.Length, (sbyte*)(tensor + sizeof(Int64)), size, status); + + status.Check(true); + return handle; + } + return CreateTensorWithoutCopying(dt, shape, data, 0, data.Length, element_size); + } + + /// + /// Creates a new tensor from a subsection of the given array without copying memory. The array is pinned down and the pointer passed on. + /// + /// Represents the tensor shape. + /// The linear array of data, the data must fit in the tensor with the specified dimensions. + /// The offset into the provided data array where the data resides. + /// The number of elements to copy from data. + /// The number of bytes in memory of a single array element + /// + /// Use the FromBuffer method to create a tensor that has the specified dimensions + /// and is initialized with data from the data array. The data is copied starting + /// at the start offset, for count bytes and is laid out into the tensor following the + /// specified dimensions. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + protected unsafe IntPtr CreateTensorWithoutCopying(TF_DataType dt, long[] shape, Array data, int start, int count, int element_size) + { + if (start < 0 || start > data.Length - count) + throw new ArgumentException($"Array length {data.Length} does not match the given shape {new Shape(shape.Cast().ToArray())}"); + + // get a handle to the pinned array which we will pass on to the tensor computation engine to use + var gcHandle = GCHandle.Alloc(data, GCHandleType.Pinned); + _deallocatorArgs = new DeallocatorArgs() { gc_handle = GCHandle.ToIntPtr(gcHandle) }; + if (shape == null || shape.Length == 0) + return TF_NewTensor(dt, new long[0], 0, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr)(count * element_size), _gcHandleDeallocator, ref _deallocatorArgs); + else + return TF_NewTensor(dt, shape, shape.Length, gcHandle.AddrOfPinnedObject() + start * element_size, (UIntPtr)(count * element_size), _gcHandleDeallocator, ref _deallocatorArgs); + } + + [MonoPInvokeCallback(typeof(Deallocator))] + internal static void FreeHGlobalMemory(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args) + { + if (args.deallocator_called) + return; + Marshal.FreeHGlobal(dataPtr); + args.deallocator_called = true; + } + + [MonoPInvokeCallback(typeof(Deallocator))] + internal static void FreeGCHandle(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args) + { + if (args.deallocator_called || args.gc_handle == IntPtr.Zero) + return; + // note: since the ptr given to tensorflow is just the addr of the pinned object we can not directly free it! we need to free the gcHandle instead + GCHandle.FromIntPtr(args.gc_handle).Free(); + args.deallocator_called = true; + } + + [MonoPInvokeCallback(typeof(Deallocator))] + internal static void FreeNothing(IntPtr dataPtr, IntPtr len, ref DeallocatorArgs args) + { + args.deallocator_called = true; + } + + } + + /// + /// This attribute can be applied to callback functions that will be invoked + /// from unmanaged code to managed code. + /// + /// + /// + /// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))] + /// internal static void MyFreeFunc (IntPtr data, IntPtr length){..} + /// + /// + public sealed class MonoPInvokeCallbackAttribute : Attribute + { + /// + /// Use this constructor to annotate the type of the callback function that + /// will be invoked from unmanaged code. + /// + /// T. + public MonoPInvokeCallbackAttribute(Type t) { } + } + +} diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 54b58122..700673b7 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -19,6 +19,7 @@ using System; using System.Collections.Generic; using System.Linq; using System.Runtime.InteropServices; +using System.Text; using Tensorflow.Framework; using static Tensorflow.Python; @@ -48,8 +49,6 @@ namespace Tensorflow private int _value_index; public int value_index => _value_index; - private Status status = new Status(); - private TF_DataType _dtype = TF_DataType.DtInvalid; public TF_DataType dtype => _handle == IntPtr.Zero ? _dtype : c_api.TF_TensorType(_handle); @@ -76,6 +75,7 @@ namespace Tensorflow if (_handle == IntPtr.Zero) { + var status = new Status(); c_api.TF_GraphGetTensorShape(op.graph, _as_tf_output(), dims, rank, status); status.Check(); } @@ -90,6 +90,8 @@ namespace Tensorflow set { + var status = new Status(); + if (value == null) c_api.TF_GraphSetTensorShape(this.graph, this._as_tf_output(), null, -1, status); else @@ -131,6 +133,7 @@ namespace Tensorflow { if (_handle == IntPtr.Zero) { + var status = new Status(); var output = _as_tf_output(); return c_api.TF_GraphGetTensorNumDims(op.graph, output, status); } @@ -184,6 +187,41 @@ namespace Tensorflow return data; } + public unsafe string[] StringData() + { + // + // TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes. + // [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes] + // + long size = 1; + foreach (var s in TensorShape.Dimensions) + size *= s; + + var buffer = new byte[size][]; + var src = c_api.TF_TensorData(_handle); + var srcLen = (IntPtr)(src.ToInt64() + (long)bytesize); + src += (int)(size * 8); + for (int i = 0; i < buffer.Length; i++) + { + using (var status = new Status()) + { + IntPtr dst = IntPtr.Zero; + UIntPtr dstLen = UIntPtr.Zero; + var read = c_api.TF_StringDecode((byte*)src, (UIntPtr)(srcLen.ToInt64() - src.ToInt64()), (byte**)&dst, &dstLen, status); + status.Check(true); + buffer[i] = new byte[(int)dstLen]; + Marshal.Copy(dst, buffer[i], 0, buffer[i].Length); + src += (int)read; + } + } + + var _str = new string[buffer.Length]; + for (int i = 0; i < _str.Length; i++) + _str[i] = Encoding.UTF8.GetString(buffer[i]); + + return _str; + } + public Tensor MaybeMove() { var tensor = c_api.TF_TensorMaybeMove(_handle); @@ -364,7 +402,7 @@ namespace Tensorflow } if (h != IntPtr.Zero) c_api.TF_DeleteTensor(h); - status.Dispose(); + GC.SuppressFinalize(this); } diff --git a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs index cf208ed2..fd240ee7 100644 --- a/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/c_api.tensor.cs @@ -32,6 +32,9 @@ namespace Tensorflow [DllImport(TensorFlowLibName)] public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, IntPtr dims, int num_dims, UIntPtr len); + [DllImport(TensorFlowLibName)] + public static extern IntPtr TF_AllocateTensor(TF_DataType dtype, long[] dims, int num_dims, UIntPtr len); + /// /// returns the sizeof() for the underlying type corresponding to the given TF_DataType enum value. /// @@ -150,5 +153,8 @@ namespace Tensorflow /// [DllImport(TensorFlowLibName)] public static extern ulong TF_StringDecode(IntPtr src, ulong src_len, IntPtr dst, ref ulong dst_len, IntPtr status); + + [DllImport(TensorFlowLibName)] + public static extern unsafe UIntPtr TF_StringDecode(byte* src, UIntPtr src_len, byte** dst, UIntPtr* dst_len, IntPtr status); } } From 2c8771429c94c8780b19cd8aff82f1beda699c93 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Thu, 1 Aug 2019 15:39:57 -0500 Subject: [PATCH 11/11] release v0.10.7.2. --- src/TensorFlowNET.Core/Graphs/Graph.cs | 4 ++-- src/TensorFlowNET.Core/Sessions/Session.cs | 14 ++++++++++---- src/TensorFlowNET.Core/TensorFlowNET.Core.csproj | 6 +++--- src/TensorFlowNET.Core/Tensors/Tensor.cs | 9 +++++---- .../ImageProcessing/RetrainImageClassifier.cs | 4 ++-- 5 files changed, 22 insertions(+), 15 deletions(-) diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 2acc0bb7..7121e0be 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -445,12 +445,12 @@ namespace Tensorflow public void Dispose() { - if (_handle != IntPtr.Zero) + /*if (_handle != IntPtr.Zero) c_api.TF_DeleteGraph(_handle); _handle = IntPtr.Zero; - GC.SuppressFinalize(this); + GC.SuppressFinalize(this);*/ } /// diff --git a/src/TensorFlowNET.Core/Sessions/Session.cs b/src/TensorFlowNET.Core/Sessions/Session.cs index 3e7dca84..c85e0598 100644 --- a/src/TensorFlowNET.Core/Sessions/Session.cs +++ b/src/TensorFlowNET.Core/Sessions/Session.cs @@ -83,13 +83,19 @@ namespace Tensorflow public void Dispose() { - if (_session != IntPtr.Zero) + IntPtr h = IntPtr.Zero; + lock (this) + { + h = _session; + _session = IntPtr.Zero; + } + if (h != IntPtr.Zero) { var status = new Status(); - c_api.TF_DeleteSession(_session, status); + c_api.TF_DeleteSession(h, status); + status.Check(true); } - - _session = IntPtr.Zero; + GC.SuppressFinalize(this); } diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 36ff24ac..b7d0d36c 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -5,7 +5,7 @@ TensorFlow.NET Tensorflow 1.14.0 - 0.10.7 + 0.10.7.2 Haiping Chen, Meinrad Recheis SciSharp STACK true @@ -17,7 +17,7 @@ TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C# Google's TensorFlow full binding in .NET Standard. Docs: https://tensorflownet.readthedocs.io - 0.10.7.0 + 0.10.7.2 Changes since v0.9.0: 1. Added full connected Convolution Neural Network example. @@ -36,7 +36,7 @@ Docs: https://tensorflownet.readthedocs.io 14. Fix TF_String endcode and decode. 15. Fix Tensor memory leak. 7.2 - 0.10.7.0 + 0.10.7.2 LICENSE true true diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs index 700673b7..6d1a0783 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs @@ -135,7 +135,9 @@ namespace Tensorflow { var status = new Status(); var output = _as_tf_output(); - return c_api.TF_GraphGetTensorNumDims(op.graph, output, status); + int ndim = c_api.TF_GraphGetTensorNumDims(op.graph, output, status); + status.Check(); + return ndim; } else { @@ -394,15 +396,14 @@ namespace Tensorflow public void Dispose() { - IntPtr h=IntPtr.Zero; + IntPtr h = IntPtr.Zero; lock (this) { h = _handle; - _handle=IntPtr.Zero; + _handle = IntPtr.Zero; } if (h != IntPtr.Zero) c_api.TF_DeleteTensor(h); - GC.SuppressFinalize(this); } diff --git a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs index 96da8d1c..ee462c4c 100644 --- a/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs +++ b/test/TensorFlowNET.Examples/ImageProcessing/RetrainImageClassifier.cs @@ -83,10 +83,10 @@ namespace TensorFlowNET.Examples #region For debug purpose // predict images - Predict(null); + // Predict(null); // load saved pb and test new images. - Test(null); + // Test(null); #endregion