Browse Source

Merge branch 'master' into Conv_1D

tags/v0.40-tf2.4-tstring
Niklas Gustafsson 4 years ago
parent
commit
58f3194909
17 changed files with 121 additions and 48 deletions
  1. +21
    -3
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Data/MnistDataSet.cs
  3. +8
    -8
      src/TensorFlowNET.Core/Data/Utils.cs
  4. +8
    -8
      src/TensorFlowNET.Core/Framework/c_api_util.cs
  5. +2
    -2
      src/TensorFlowNET.Core/Framework/meta_graph.cs
  6. +1
    -1
      src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs
  7. +1
    -1
      src/TensorFlowNET.Core/Training/Saving/Saver.cs
  8. +2
    -2
      src/TensorFlowNET.Core/Training/Saving/saver.py.cs
  9. +2
    -2
      src/TensorFlowNET.Core/Util/CmdHelper.cs
  10. +1
    -1
      src/TensorFlowNET.Keras/Datasets/MNIST.cs
  11. +2
    -2
      src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs
  12. +1
    -1
      src/TensorFlowNET.Keras/Engine/Model.Fit.cs
  13. +2
    -2
      src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs
  14. +1
    -1
      src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs
  15. +8
    -8
      src/TensorFlowNET.Keras/Utils/Compress.cs
  16. +5
    -5
      src/TensorFlowNET.Keras/Utils/Web.cs
  17. +55
    -0
      test/TensorFlowNET.Keras.UnitTest/OutputTest.cs

+ 21
- 3
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -21,6 +21,7 @@ using System.Collections;
using System.Collections.Generic; using System.Collections.Generic;
using System.ComponentModel; using System.ComponentModel;
using System.Diagnostics; using System.Diagnostics;
using System.IO;
using System.Linq; using System.Linq;


namespace Tensorflow namespace Tensorflow
@@ -112,16 +113,33 @@ namespace Tensorflow
} }
} }


private static TextWriter writer = null;

public static TextWriter tf_output_redirect {
set
{
var originWriter = writer ?? Console.Out;
originWriter.Flush();
if (originWriter is StringWriter)
(originWriter as StringWriter).GetStringBuilder().Clear();
writer = value;
}
get
{
return writer ?? Console.Out;
}
}

public static void print(object obj) public static void print(object obj)
{ {
Console.WriteLine(_tostring(obj));
tf_output_redirect.WriteLine(_tostring(obj));
} }


public static void print(string format, params object[] objects) public static void print(string format, params object[] objects)
{ {
if (!format.Contains("{}")) if (!format.Contains("{}"))
{ {
Console.WriteLine(format + " " + string.Join(" ", objects.Select(x => x.ToString())));
tf_output_redirect.WriteLine(format + " " + string.Join(" ", objects.Select(x => x.ToString())));
return; return;
} }


@@ -130,7 +148,7 @@ namespace Tensorflow


} }


Console.WriteLine(format);
tf_output_redirect.WriteLine(format);
} }


public static int len(object a) public static int len(object a)


+ 1
- 1
src/TensorFlowNET.Core/Data/MnistDataSet.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow
sw.Start(); sw.Start();
images = np.multiply(images, 1.0f / 255.0f); images = np.multiply(images, 1.0f / 255.0f);
sw.Stop(); sw.Stop();
Console.WriteLine($"{sw.ElapsedMilliseconds}ms");
Binding.tf_output_redirect.WriteLine($"{sw.ElapsedMilliseconds}ms");
Data = images; Data = images;


labels = labels.astype(dataType); labels = labels.astype(dataType);


+ 8
- 8
src/TensorFlowNET.Core/Data/Utils.cs View File

@@ -27,14 +27,14 @@ namespace Tensorflow


if (showProgressInConsole) if (showProgressInConsole)
{ {
Console.WriteLine($"Downloading {fileName}");
Binding.tf_output_redirect.WriteLine($"Downloading {fileName}");
} }


if (File.Exists(fileSaveTo)) if (File.Exists(fileSaveTo))
{ {
if (showProgressInConsole) if (showProgressInConsole)
{ {
Console.WriteLine($"The file {fileName} already exists");
Binding.tf_output_redirect.WriteLine($"The file {fileName} already exists");
} }


return; return;
@@ -64,12 +64,12 @@ namespace Tensorflow
var destFilePath = Path.Combine(saveTo, destFileName); var destFilePath = Path.Combine(saveTo, destFileName);


if (showProgressInConsole) if (showProgressInConsole)
Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}");
Binding.tf_output_redirect.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}");


if (File.Exists(destFilePath)) if (File.Exists(destFilePath))
{ {
if (showProgressInConsole) if (showProgressInConsole)
Console.WriteLine($"The file {destFileName} already exists");
Binding.tf_output_redirect.WriteLine($"The file {destFileName} already exists");
} }


using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress))
@@ -107,7 +107,7 @@ namespace Tensorflow
} }


await showProgressTask; await showProgressTask;
Console.WriteLine("Done.");
Binding.tf_output_redirect.WriteLine("Done.");
} }


private static async Task ShowProgressInConsole(CancellationTokenSource cts) private static async Task ShowProgressInConsole(CancellationTokenSource cts)
@@ -119,17 +119,17 @@ namespace Tensorflow
while (!cts.IsCancellationRequested) while (!cts.IsCancellationRequested)
{ {
await Task.Delay(100); await Task.Delay(100);
Console.Write(".");
Binding.tf_output_redirect.Write(".");
cols++; cols++;


if (cols % 50 == 0) if (cols % 50 == 0)
{ {
Console.WriteLine();
Binding.tf_output_redirect.WriteLine();
} }
} }


if (cols > 0) if (cols > 0)
Console.WriteLine();
Binding.tf_output_redirect.WriteLine();
} }
} }
} }

+ 8
- 8
src/TensorFlowNET.Core/Framework/c_api_util.cs View File

@@ -62,18 +62,18 @@ namespace Tensorflow
if (!File.Exists(file)) if (!File.Exists(file))
{ {
var wc = new WebClient(); var wc = new WebClient();
Console.WriteLine($"Downloading Tensorflow library from {url}...");
Binding.tf_output_redirect.WriteLine($"Downloading Tensorflow library from {url}...");
var download = Task.Run(() => wc.DownloadFile(url, file)); var download = Task.Run(() => wc.DownloadFile(url, file));
while (!download.IsCompleted) while (!download.IsCompleted)
{ {
Thread.Sleep(1000); Thread.Sleep(1000);
Console.Write(".");
Binding.tf_output_redirect.Write(".");
} }
Console.WriteLine("");
Console.WriteLine($"Downloaded successfully.");
Binding.tf_output_redirect.WriteLine("");
Binding.tf_output_redirect.WriteLine($"Downloaded successfully.");
} }


Console.WriteLine($"Extracting...");
Binding.tf_output_redirect.WriteLine($"Extracting...");
var task = Task.Run(() => var task = Task.Run(() =>
{ {
switch (Environment.OSVersion.Platform) switch (Environment.OSVersion.Platform)
@@ -97,11 +97,11 @@ namespace Tensorflow
while (!task.IsCompleted) while (!task.IsCompleted)
{ {
Thread.Sleep(100); Thread.Sleep(100);
Console.Write(".");
Binding.tf_output_redirect.Write(".");
} }


Console.WriteLine("");
Console.WriteLine("Extraction is completed.");
Binding.tf_output_redirect.WriteLine("");
Binding.tf_output_redirect.WriteLine("Extraction is completed.");
} }


isDllDownloaded = true; isDllDownloaded = true;


+ 2
- 2
src/TensorFlowNET.Core/Framework/meta_graph.cs View File

@@ -134,7 +134,7 @@ namespace Tensorflow
} }
break; break;
default: default:
Console.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}");
Binding.tf_output_redirect.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}");
continue; continue;
} }
} }
@@ -142,7 +142,7 @@ namespace Tensorflow


break; break;
default: default:
Console.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping.");
Binding.tf_output_redirect.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping.");
break; break;
} }
} }


+ 1
- 1
src/TensorFlowNET.Core/Operations/Distributions/distribution.py.cs View File

@@ -166,7 +166,7 @@ namespace Tensorflow


public void repr() public void repr()
{ {
Console.WriteLine($"<Reparameteriation Type: {this._rep_type}>");
Binding.tf_output_redirect.WriteLine($"<Reparameteriation Type: {this._rep_type}>");
} }


public bool eq(ReparameterizationType other) public bool eq(ReparameterizationType other)


+ 1
- 1
src/TensorFlowNET.Core/Training/Saving/Saver.cs View File

@@ -242,7 +242,7 @@ namespace Tensorflow
if (!checkpoint_management.checkpoint_exists(save_path)) if (!checkpoint_management.checkpoint_exists(save_path))
throw new ValueError($"The passed save_path is not a valid checkpoint: {save_path}"); throw new ValueError($"The passed save_path is not a valid checkpoint: {save_path}");


Console.WriteLine($"Restoring parameters from {save_path}");
Binding.tf_output_redirect.WriteLine($"Restoring parameters from {save_path}");


if (tf.Context.executing_eagerly()) if (tf.Context.executing_eagerly())
#pragma warning disable CS0642 // Possible mistaken empty statement #pragma warning disable CS0642 // Possible mistaken empty statement


+ 2
- 2
src/TensorFlowNET.Core/Training/Saving/saver.py.cs View File

@@ -78,7 +78,7 @@ namespace Tensorflow
else else
{ {
// If no graph variables exist, then a Saver cannot be constructed. // If no graph variables exist, then a Saver cannot be constructed.
Console.WriteLine("Saver not created because there are no variables in the" +
Binding.tf_output_redirect.WriteLine("Saver not created because there are no variables in the" +
" graph to restore"); " graph to restore");
return null; return null;
} }
@@ -102,7 +102,7 @@ namespace Tensorflow
var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, var output_graph_def = tf.graph_util.convert_variables_to_constants(sess,
graph.as_graph_def(), graph.as_graph_def(),
output_node_names); output_node_names);
Console.WriteLine($"Froze {output_graph_def.Node.Count} nodes.");
Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes.");
File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); File.WriteAllBytes(output_pb, output_graph_def.ToByteArray());
return output_pb; return output_pb;
} }


+ 2
- 2
src/TensorFlowNET.Core/Util/CmdHelper.cs View File

@@ -31,7 +31,7 @@ namespace Tensorflow.Util
proc.Start(); proc.Start();


while (!proc.StandardOutput.EndOfStream) while (!proc.StandardOutput.EndOfStream)
Console.WriteLine(proc.StandardOutput.ReadLine());
Binding.tf_output_redirect.WriteLine(proc.StandardOutput.ReadLine());
} }


public static void Bash(string command) public static void Bash(string command)
@@ -44,7 +44,7 @@ namespace Tensorflow.Util
proc.Start(); proc.Start();


while (!proc.StandardOutput.EndOfStream) while (!proc.StandardOutput.EndOfStream)
Console.WriteLine(proc.StandardOutput.ReadLine());
Binding.tf_output_redirect.WriteLine(proc.StandardOutput.ReadLine());
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Keras/Datasets/MNIST.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow.Keras.Datasets


if (File.Exists(fileSaveTo)) if (File.Exists(fileSaveTo))
{ {
Console.WriteLine($"The file {fileSaveTo} already exists");
Binding.tf_output_redirect.WriteLine($"The file {fileSaveTo} already exists");
return fileSaveTo; return fileSaveTo;
} }




+ 2
- 2
src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs View File

@@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine
StepsPerExecution = _steps_per_execution StepsPerExecution = _steps_per_execution
}); });


Console.WriteLine($"Testing...");
Binding.tf_output_redirect.WriteLine($"Testing...");
foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
{ {
// reset_metrics(); // reset_metrics();
@@ -58,7 +58,7 @@ namespace Tensorflow.Keras.Engine
// callbacks.on_train_batch_begin(step) // callbacks.on_train_batch_begin(step)
results = test_function(iterator); results = test_function(iterator);
} }
Console.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
} }
} }




+ 1
- 1
src/TensorFlowNET.Keras/Engine/Model.Fit.cs View File

@@ -99,7 +99,7 @@ namespace Tensorflow.Keras.Engine
if (verbose == 1) if (verbose == 1)
{ {
var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}"));
Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}");
} }
} }




+ 2
- 2
src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.get_training_or_validation_split.cs View File

@@ -24,13 +24,13 @@ namespace Tensorflow.Keras.Preprocessings
var num_val_samples = Convert.ToInt32(samples.Length * validation_split); var num_val_samples = Convert.ToInt32(samples.Length * validation_split);
if (subset == "training") if (subset == "training")
{ {
Console.WriteLine($"Using {samples.Length - num_val_samples} files for training.");
Binding.tf_output_redirect.WriteLine($"Using {samples.Length - num_val_samples} files for training.");
samples = samples.Take(samples.Length - num_val_samples).ToArray(); samples = samples.Take(samples.Length - num_val_samples).ToArray();
labels = labels.Take(labels.Length - num_val_samples).ToArray(); labels = labels.Take(labels.Length - num_val_samples).ToArray();
} }
else if (subset == "validation") else if (subset == "validation")
{ {
Console.WriteLine($"Using {num_val_samples} files for validation.");
Binding.tf_output_redirect.WriteLine($"Using {num_val_samples} files for validation.");
samples = samples.Skip(samples.Length - num_val_samples).ToArray(); samples = samples.Skip(samples.Length - num_val_samples).ToArray();
labels = labels.Skip(labels.Length - num_val_samples).ToArray(); labels = labels.Skip(labels.Length - num_val_samples).ToArray();
} }


+ 1
- 1
src/TensorFlowNET.Keras/Preprocessings/DatasetUtils.index_directory.cs View File

@@ -61,7 +61,7 @@ namespace Tensorflow.Keras.Preprocessings
} }
} }


Console.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes.");
Binding.tf_output_redirect.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes.");
return (return_file_paths, return_labels, class_names); return (return_file_paths, return_labels, class_names);
} }
} }


+ 8
- 8
src/TensorFlowNET.Keras/Utils/Compress.cs View File

@@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Utils
var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin"; var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
if (File.Exists(Path.Combine(destFolder, flag))) return; if (File.Exists(Path.Combine(destFolder, flag))) return;


Console.WriteLine($"Extracting.");
Binding.tf_output_redirect.WriteLine($"Extracting.");
var task = Task.Run(() => var task = Task.Run(() =>
{ {
ZipFile.ExtractToDirectory(gzArchiveName, destFolder); ZipFile.ExtractToDirectory(gzArchiveName, destFolder);
@@ -62,12 +62,12 @@ namespace Tensorflow.Keras.Utils
while (!task.IsCompleted) while (!task.IsCompleted)
{ {
Thread.Sleep(200); Thread.Sleep(200);
Console.Write(".");
Binding.tf_output_redirect.Write(".");
} }


File.Create(Path.Combine(destFolder, flag)); File.Create(Path.Combine(destFolder, flag));
Console.WriteLine("");
Console.WriteLine("Extracting is completed.");
Binding.tf_output_redirect.WriteLine("");
Binding.tf_output_redirect.WriteLine("Extracting is completed.");
} }


public static void ExtractTGZ(String gzArchiveName, String destFolder) public static void ExtractTGZ(String gzArchiveName, String destFolder)
@@ -75,7 +75,7 @@ namespace Tensorflow.Keras.Utils
var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin"; var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
if (File.Exists(Path.Combine(destFolder, flag))) return; if (File.Exists(Path.Combine(destFolder, flag))) return;


Console.WriteLine($"Extracting.");
Binding.tf_output_redirect.WriteLine($"Extracting.");
var task = Task.Run(() => var task = Task.Run(() =>
{ {
using (var inStream = File.OpenRead(gzArchiveName)) using (var inStream = File.OpenRead(gzArchiveName))
@@ -91,12 +91,12 @@ namespace Tensorflow.Keras.Utils
while (!task.IsCompleted) while (!task.IsCompleted)
{ {
Thread.Sleep(200); Thread.Sleep(200);
Console.Write(".");
Binding.tf_output_redirect.Write(".");
} }


File.Create(Path.Combine(destFolder, flag)); File.Create(Path.Combine(destFolder, flag));
Console.WriteLine("");
Console.WriteLine("Extracting is completed.");
Binding.tf_output_redirect.WriteLine("");
Binding.tf_output_redirect.WriteLine("Extracting is completed.");
} }
} }
} }

+ 5
- 5
src/TensorFlowNET.Keras/Utils/Web.cs View File

@@ -36,20 +36,20 @@ namespace Tensorflow.Keras.Utils


if (File.Exists(relativeFilePath)) if (File.Exists(relativeFilePath))
{ {
Console.WriteLine($"{relativeFilePath} already exists.");
Binding.tf_output_redirect.WriteLine($"{relativeFilePath} already exists.");
return false; return false;
} }


var wc = new WebClient(); var wc = new WebClient();
Console.WriteLine($"Downloading from {url}");
Binding.tf_output_redirect.WriteLine($"Downloading from {url}");
var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath)); var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath));
while (!download.IsCompleted) while (!download.IsCompleted)
{ {
Thread.Sleep(1000); Thread.Sleep(1000);
Console.Write(".");
Binding.tf_output_redirect.Write(".");
} }
Console.WriteLine("");
Console.WriteLine($"Downloaded to {relativeFilePath}");
Binding.tf_output_redirect.WriteLine("");
Binding.tf_output_redirect.WriteLine($"Downloaded to {relativeFilePath}");


return true; return true;
} }


+ 55
- 0
test/TensorFlowNET.Keras.UnitTest/OutputTest.cs View File

@@ -0,0 +1,55 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.Keras;

namespace Tensorflow.Keras.UnitTest
{
[TestClass]
public class OutputTest
{
[TestMethod]
public void OutputRedirectTest()
{
using var newOutput = new System.IO.StringWriter();
tf_output_redirect = newOutput;
var model = keras.Sequential();
model.add(keras.Input(shape: 16));
model.summary();
string output = newOutput.ToString();
Assert.IsTrue(output.StartsWith("Model: sequential"));
tf_output_redirect = null; // don't forget to change it to null !!!!
}

[TestMethod]
public void SwitchOutputsTest()
{
using var newOutput = new System.IO.StringWriter();
var model = keras.Sequential();
model.add(keras.Input(shape: 16));
model.summary(); // Console.Out

tf_output_redirect = newOutput; // change to the custom one
model.summary();
string firstOutput = newOutput.ToString();
Assert.IsTrue(firstOutput.StartsWith("Model: sequential"));

// if tf_output_reditect is StringWriter, calling "set" will make the writer clear.
tf_output_redirect = null; // null means Console.Out
model.summary();

tf_output_redirect = newOutput; // again, to test whether the newOutput is clear.
model.summary();
string secondOutput = newOutput.ToString();
Assert.IsTrue(secondOutput.StartsWith("Model: sequential"));

Assert.IsTrue(firstOutput == secondOutput);
tf_output_redirect = null; // don't forget to change it to null !!!!
}
}
}

Loading…
Cancel
Save