| @@ -21,6 +21,7 @@ using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.ComponentModel; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Linq; | |||
| namespace Tensorflow | |||
| @@ -112,16 +113,33 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private static TextWriter writer = null; | |||
| public static TextWriter tf_output_redirect { | |||
| set | |||
| { | |||
| var originWriter = writer ?? Console.Out; | |||
| originWriter.Flush(); | |||
| if (originWriter is StringWriter) | |||
| (originWriter as StringWriter).GetStringBuilder().Clear(); | |||
| writer = value; | |||
| } | |||
| get | |||
| { | |||
| return writer ?? Console.Out; | |||
| } | |||
| } | |||
| public static void print(object obj) | |||
| { | |||
| Console.WriteLine(_tostring(obj)); | |||
| tf_output_redirect.WriteLine(_tostring(obj)); | |||
| } | |||
| public static void print(string format, params object[] objects) | |||
| { | |||
| if (!format.Contains("{}")) | |||
| { | |||
| Console.WriteLine(format + " " + string.Join(" ", objects.Select(x => x.ToString()))); | |||
| tf_output_redirect.WriteLine(format + " " + string.Join(" ", objects.Select(x => x.ToString()))); | |||
| return; | |||
| } | |||
| @@ -130,7 +148,7 @@ namespace Tensorflow | |||
| } | |||
| Console.WriteLine(format); | |||
| tf_output_redirect.WriteLine(format); | |||
| } | |||
| public static int len(object a) | |||
| @@ -24,7 +24,7 @@ namespace Tensorflow | |||
| sw.Start(); | |||
| images = np.multiply(images, 1.0f / 255.0f); | |||
| sw.Stop(); | |||
| Console.WriteLine($"{sw.ElapsedMilliseconds}ms"); | |||
| Binding.tf_output_redirect.WriteLine($"{sw.ElapsedMilliseconds}ms"); | |||
| Data = images; | |||
| labels = labels.astype(dataType); | |||
| @@ -27,14 +27,14 @@ namespace Tensorflow | |||
| if (showProgressInConsole) | |||
| { | |||
| Console.WriteLine($"Downloading {fileName}"); | |||
| Binding.tf_output_redirect.WriteLine($"Downloading {fileName}"); | |||
| } | |||
| if (File.Exists(fileSaveTo)) | |||
| { | |||
| if (showProgressInConsole) | |||
| { | |||
| Console.WriteLine($"The file {fileName} already exists"); | |||
| Binding.tf_output_redirect.WriteLine($"The file {fileName} already exists"); | |||
| } | |||
| return; | |||
| @@ -64,12 +64,12 @@ namespace Tensorflow | |||
| var destFilePath = Path.Combine(saveTo, destFileName); | |||
| if (showProgressInConsole) | |||
| Console.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}"); | |||
| Binding.tf_output_redirect.WriteLine($"Unzippinng {Path.GetFileName(zipFile)}"); | |||
| if (File.Exists(destFilePath)) | |||
| { | |||
| if (showProgressInConsole) | |||
| Console.WriteLine($"The file {destFileName} already exists"); | |||
| Binding.tf_output_redirect.WriteLine($"The file {destFileName} already exists"); | |||
| } | |||
| using (GZipStream unzipStream = new GZipStream(File.OpenRead(zipFile), CompressionMode.Decompress)) | |||
| @@ -107,7 +107,7 @@ namespace Tensorflow | |||
| } | |||
| await showProgressTask; | |||
| Console.WriteLine("Done."); | |||
| Binding.tf_output_redirect.WriteLine("Done."); | |||
| } | |||
| private static async Task ShowProgressInConsole(CancellationTokenSource cts) | |||
| @@ -119,17 +119,17 @@ namespace Tensorflow | |||
| while (!cts.IsCancellationRequested) | |||
| { | |||
| await Task.Delay(100); | |||
| Console.Write("."); | |||
| Binding.tf_output_redirect.Write("."); | |||
| cols++; | |||
| if (cols % 50 == 0) | |||
| { | |||
| Console.WriteLine(); | |||
| Binding.tf_output_redirect.WriteLine(); | |||
| } | |||
| } | |||
| if (cols > 0) | |||
| Console.WriteLine(); | |||
| Binding.tf_output_redirect.WriteLine(); | |||
| } | |||
| } | |||
| } | |||
| @@ -62,18 +62,18 @@ namespace Tensorflow | |||
| if (!File.Exists(file)) | |||
| { | |||
| var wc = new WebClient(); | |||
| Console.WriteLine($"Downloading Tensorflow library from {url}..."); | |||
| Binding.tf_output_redirect.WriteLine($"Downloading Tensorflow library from {url}..."); | |||
| var download = Task.Run(() => wc.DownloadFile(url, file)); | |||
| while (!download.IsCompleted) | |||
| { | |||
| Thread.Sleep(1000); | |||
| Console.Write("."); | |||
| Binding.tf_output_redirect.Write("."); | |||
| } | |||
| Console.WriteLine(""); | |||
| Console.WriteLine($"Downloaded successfully."); | |||
| Binding.tf_output_redirect.WriteLine(""); | |||
| Binding.tf_output_redirect.WriteLine($"Downloaded successfully."); | |||
| } | |||
| Console.WriteLine($"Extracting..."); | |||
| Binding.tf_output_redirect.WriteLine($"Extracting..."); | |||
| var task = Task.Run(() => | |||
| { | |||
| switch (Environment.OSVersion.Platform) | |||
| @@ -97,11 +97,11 @@ namespace Tensorflow | |||
| while (!task.IsCompleted) | |||
| { | |||
| Thread.Sleep(100); | |||
| Console.Write("."); | |||
| Binding.tf_output_redirect.Write("."); | |||
| } | |||
| Console.WriteLine(""); | |||
| Console.WriteLine("Extraction is completed."); | |||
| Binding.tf_output_redirect.WriteLine(""); | |||
| Binding.tf_output_redirect.WriteLine("Extraction is completed."); | |||
| } | |||
| isDllDownloaded = true; | |||
| @@ -134,7 +134,7 @@ namespace Tensorflow | |||
| } | |||
| break; | |||
| default: | |||
| Console.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}"); | |||
| Binding.tf_output_redirect.WriteLine($"import_scoped_meta_graph_with_return_elements {col.Key}"); | |||
| continue; | |||
| } | |||
| } | |||
| @@ -142,7 +142,7 @@ namespace Tensorflow | |||
| break; | |||
| default: | |||
| Console.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping."); | |||
| Binding.tf_output_redirect.WriteLine($"Cannot identify data type for collection {col.Key}. Skipping."); | |||
| break; | |||
| } | |||
| } | |||
| @@ -166,7 +166,7 @@ namespace Tensorflow | |||
| public void repr() | |||
| { | |||
| Console.WriteLine($"<Reparameteriation Type: {this._rep_type}>"); | |||
| Binding.tf_output_redirect.WriteLine($"<Reparameteriation Type: {this._rep_type}>"); | |||
| } | |||
| public bool eq(ReparameterizationType other) | |||
| @@ -242,7 +242,7 @@ namespace Tensorflow | |||
| if (!checkpoint_management.checkpoint_exists(save_path)) | |||
| throw new ValueError($"The passed save_path is not a valid checkpoint: {save_path}"); | |||
| Console.WriteLine($"Restoring parameters from {save_path}"); | |||
| Binding.tf_output_redirect.WriteLine($"Restoring parameters from {save_path}"); | |||
| if (tf.Context.executing_eagerly()) | |||
| #pragma warning disable CS0642 // Possible mistaken empty statement | |||
| @@ -78,7 +78,7 @@ namespace Tensorflow | |||
| else | |||
| { | |||
| // If no graph variables exist, then a Saver cannot be constructed. | |||
| Console.WriteLine("Saver not created because there are no variables in the" + | |||
| Binding.tf_output_redirect.WriteLine("Saver not created because there are no variables in the" + | |||
| " graph to restore"); | |||
| return null; | |||
| } | |||
| @@ -102,7 +102,7 @@ namespace Tensorflow | |||
| var output_graph_def = tf.graph_util.convert_variables_to_constants(sess, | |||
| graph.as_graph_def(), | |||
| output_node_names); | |||
| Console.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); | |||
| Binding.tf_output_redirect.WriteLine($"Froze {output_graph_def.Node.Count} nodes."); | |||
| File.WriteAllBytes(output_pb, output_graph_def.ToByteArray()); | |||
| return output_pb; | |||
| } | |||
| @@ -31,7 +31,7 @@ namespace Tensorflow.Util | |||
| proc.Start(); | |||
| while (!proc.StandardOutput.EndOfStream) | |||
| Console.WriteLine(proc.StandardOutput.ReadLine()); | |||
| Binding.tf_output_redirect.WriteLine(proc.StandardOutput.ReadLine()); | |||
| } | |||
| public static void Bash(string command) | |||
| @@ -44,7 +44,7 @@ namespace Tensorflow.Util | |||
| proc.Start(); | |||
| while (!proc.StandardOutput.EndOfStream) | |||
| Console.WriteLine(proc.StandardOutput.ReadLine()); | |||
| Binding.tf_output_redirect.WriteLine(proc.StandardOutput.ReadLine()); | |||
| } | |||
| } | |||
| } | |||
| @@ -61,7 +61,7 @@ namespace Tensorflow.Keras.Datasets | |||
| if (File.Exists(fileSaveTo)) | |||
| { | |||
| Console.WriteLine($"The file {fileSaveTo} already exists"); | |||
| Binding.tf_output_redirect.WriteLine($"The file {fileSaveTo} already exists"); | |||
| return fileSaveTo; | |||
| } | |||
| @@ -46,7 +46,7 @@ namespace Tensorflow.Keras.Engine | |||
| StepsPerExecution = _steps_per_execution | |||
| }); | |||
| Console.WriteLine($"Testing..."); | |||
| Binding.tf_output_redirect.WriteLine($"Testing..."); | |||
| foreach (var (epoch, iterator) in data_handler.enumerate_epochs()) | |||
| { | |||
| // reset_metrics(); | |||
| @@ -58,7 +58,7 @@ namespace Tensorflow.Keras.Engine | |||
| // callbacks.on_train_batch_begin(step) | |||
| results = test_function(iterator); | |||
| } | |||
| Console.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); | |||
| Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}"))); | |||
| } | |||
| } | |||
| @@ -99,7 +99,7 @@ namespace Tensorflow.Keras.Engine | |||
| if (verbose == 1) | |||
| { | |||
| var result_pairs = string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2:F6}")); | |||
| Console.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); | |||
| Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{epochs:D3}, Step: {step + 1:D4}/{data_handler.Inferredsteps:D4}, {result_pairs}"); | |||
| } | |||
| } | |||
| @@ -24,13 +24,13 @@ namespace Tensorflow.Keras.Preprocessings | |||
| var num_val_samples = Convert.ToInt32(samples.Length * validation_split); | |||
| if (subset == "training") | |||
| { | |||
| Console.WriteLine($"Using {samples.Length - num_val_samples} files for training."); | |||
| Binding.tf_output_redirect.WriteLine($"Using {samples.Length - num_val_samples} files for training."); | |||
| samples = samples.Take(samples.Length - num_val_samples).ToArray(); | |||
| labels = labels.Take(labels.Length - num_val_samples).ToArray(); | |||
| } | |||
| else if (subset == "validation") | |||
| { | |||
| Console.WriteLine($"Using {num_val_samples} files for validation."); | |||
| Binding.tf_output_redirect.WriteLine($"Using {num_val_samples} files for validation."); | |||
| samples = samples.Skip(samples.Length - num_val_samples).ToArray(); | |||
| labels = labels.Skip(labels.Length - num_val_samples).ToArray(); | |||
| } | |||
| @@ -61,7 +61,7 @@ namespace Tensorflow.Keras.Preprocessings | |||
| } | |||
| } | |||
| Console.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes."); | |||
| Binding.tf_output_redirect.WriteLine($"Found {return_file_paths.Length} files belonging to {class_names.Length} classes."); | |||
| return (return_file_paths, return_labels, class_names); | |||
| } | |||
| } | |||
| @@ -53,7 +53,7 @@ namespace Tensorflow.Keras.Utils | |||
| var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin"; | |||
| if (File.Exists(Path.Combine(destFolder, flag))) return; | |||
| Console.WriteLine($"Extracting."); | |||
| Binding.tf_output_redirect.WriteLine($"Extracting."); | |||
| var task = Task.Run(() => | |||
| { | |||
| ZipFile.ExtractToDirectory(gzArchiveName, destFolder); | |||
| @@ -62,12 +62,12 @@ namespace Tensorflow.Keras.Utils | |||
| while (!task.IsCompleted) | |||
| { | |||
| Thread.Sleep(200); | |||
| Console.Write("."); | |||
| Binding.tf_output_redirect.Write("."); | |||
| } | |||
| File.Create(Path.Combine(destFolder, flag)); | |||
| Console.WriteLine(""); | |||
| Console.WriteLine("Extracting is completed."); | |||
| Binding.tf_output_redirect.WriteLine(""); | |||
| Binding.tf_output_redirect.WriteLine("Extracting is completed."); | |||
| } | |||
| public static void ExtractTGZ(String gzArchiveName, String destFolder) | |||
| @@ -75,7 +75,7 @@ namespace Tensorflow.Keras.Utils | |||
| var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin"; | |||
| if (File.Exists(Path.Combine(destFolder, flag))) return; | |||
| Console.WriteLine($"Extracting."); | |||
| Binding.tf_output_redirect.WriteLine($"Extracting."); | |||
| var task = Task.Run(() => | |||
| { | |||
| using (var inStream = File.OpenRead(gzArchiveName)) | |||
| @@ -91,12 +91,12 @@ namespace Tensorflow.Keras.Utils | |||
| while (!task.IsCompleted) | |||
| { | |||
| Thread.Sleep(200); | |||
| Console.Write("."); | |||
| Binding.tf_output_redirect.Write("."); | |||
| } | |||
| File.Create(Path.Combine(destFolder, flag)); | |||
| Console.WriteLine(""); | |||
| Console.WriteLine("Extracting is completed."); | |||
| Binding.tf_output_redirect.WriteLine(""); | |||
| Binding.tf_output_redirect.WriteLine("Extracting is completed."); | |||
| } | |||
| } | |||
| } | |||
| @@ -36,20 +36,20 @@ namespace Tensorflow.Keras.Utils | |||
| if (File.Exists(relativeFilePath)) | |||
| { | |||
| Console.WriteLine($"{relativeFilePath} already exists."); | |||
| Binding.tf_output_redirect.WriteLine($"{relativeFilePath} already exists."); | |||
| return false; | |||
| } | |||
| var wc = new WebClient(); | |||
| Console.WriteLine($"Downloading from {url}"); | |||
| Binding.tf_output_redirect.WriteLine($"Downloading from {url}"); | |||
| var download = Task.Run(() => wc.DownloadFile(url, relativeFilePath)); | |||
| while (!download.IsCompleted) | |||
| { | |||
| Thread.Sleep(1000); | |||
| Console.Write("."); | |||
| Binding.tf_output_redirect.Write("."); | |||
| } | |||
| Console.WriteLine(""); | |||
| Console.WriteLine($"Downloaded to {relativeFilePath}"); | |||
| Binding.tf_output_redirect.WriteLine(""); | |||
| Binding.tf_output_redirect.WriteLine($"Downloaded to {relativeFilePath}"); | |||
| return true; | |||
| } | |||
| @@ -0,0 +1,55 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| using Tensorflow.Keras; | |||
| namespace Tensorflow.Keras.UnitTest | |||
| { | |||
| [TestClass] | |||
| public class OutputTest | |||
| { | |||
| [TestMethod] | |||
| public void OutputRedirectTest() | |||
| { | |||
| using var newOutput = new System.IO.StringWriter(); | |||
| tf_output_redirect = newOutput; | |||
| var model = keras.Sequential(); | |||
| model.add(keras.Input(shape: 16)); | |||
| model.summary(); | |||
| string output = newOutput.ToString(); | |||
| Assert.IsTrue(output.StartsWith("Model: sequential")); | |||
| tf_output_redirect = null; // don't forget to change it to null !!!! | |||
| } | |||
| [TestMethod] | |||
| public void SwitchOutputsTest() | |||
| { | |||
| using var newOutput = new System.IO.StringWriter(); | |||
| var model = keras.Sequential(); | |||
| model.add(keras.Input(shape: 16)); | |||
| model.summary(); // Console.Out | |||
| tf_output_redirect = newOutput; // change to the custom one | |||
| model.summary(); | |||
| string firstOutput = newOutput.ToString(); | |||
| Assert.IsTrue(firstOutput.StartsWith("Model: sequential")); | |||
| // if tf_output_reditect is StringWriter, calling "set" will make the writer clear. | |||
| tf_output_redirect = null; // null means Console.Out | |||
| model.summary(); | |||
| tf_output_redirect = newOutput; // again, to test whether the newOutput is clear. | |||
| model.summary(); | |||
| string secondOutput = newOutput.ToString(); | |||
| Assert.IsTrue(secondOutput.StartsWith("Model: sequential")); | |||
| Assert.IsTrue(firstOutput == secondOutput); | |||
| tf_output_redirect = null; // don't forget to change it to null !!!! | |||
| } | |||
| } | |||
| } | |||