From 76abb2c6f024484124d525bf6923f719cb4080fe Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Tue, 20 Jul 2021 06:57:04 -0500 Subject: [PATCH] to_numpy_string --- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 85 +++++++++++++++++-- .../Numpy/Array.Creation.Test.cs | 10 +++ 2 files changed, 88 insertions(+), 7 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index f694de82..b93a5982 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -470,7 +470,20 @@ would not be rank 1.", tensor.op.get_attr("axis"))); return ops.convert_to_tensor(shape, dtype: TF_DataType.TF_INT32, name: "shape"); } - public static string to_numpy_string(Tensor tensor) + public static string to_numpy_string(Tensor array) + { + Shape shape = array.shape; + if (shape.ndim == 0) + return array[0].ToString(); + + var s = new StringBuilder(); + s.Append("array("); + PrettyPrint(s, array); + s.Append(")"); + return s.ToString(); + } + + static string Render(Tensor tensor) { if (tensor.buffer == IntPtr.Zero) return "Empty"; @@ -487,7 +500,7 @@ would not be rank 1.", tensor.op.get_attr("axis"))); else return $"['{string.Join("', '", tensor.StringData().Take(25))}']"; } - else if(dtype == TF_DataType.TF_VARIANT) + else if (dtype == TF_DataType.TF_VARIANT) { return ""; } @@ -515,7 +528,7 @@ would not be rank 1.", tensor.op.get_attr("axis"))); var array = tensor.ToArray(); return DisplayArrayAsString(array, tensor.shape); } - else if(dtype == TF_DataType.TF_DOUBLE) + else if (dtype == TF_DataType.TF_DOUBLE) { var array = tensor.ToArray(); return DisplayArrayAsString(array, tensor.shape); @@ -532,14 +545,72 @@ would not be rank 1.", tensor.op.get_attr("axis"))); if (shape.ndim == 0) return array[0].ToString(); - var display = "array(["; + var display = ""; if (array.Length < 10) display += string.Join(", ", array); else - display += string.Join(", ", array.Take(3)) + " ... " + string.Join(", ", array.Skip(array.Length - 3)); - return display + "])"; + display += string.Join(", ", array.Take(3)) + ", ..., " + string.Join(", ", array.Skip(array.Length - 3)); + return display; + } + + static void PrettyPrint(StringBuilder s, Tensor array, bool flat = false) + { + var shape = array.shape; + + if (shape.Length == 1) + { + s.Append("["); + s.Append(Render(array)); + s.Append("]"); + return; + } + + var len = shape[0]; + s.Append("["); + + if (len <= 10) + { + for (int i = 0; i < len; i++) + { + PrettyPrint(s, array[i], flat); + if (i < len - 1) + { + s.Append(", "); + if (!flat) + s.AppendLine(); + } + } + } + else + { + for (int i = 0; i < 5; i++) + { + PrettyPrint(s, array[i], flat); + if (i < len - 1) + { + s.Append(", "); + if (!flat) + s.AppendLine(); + } + } + + s.Append(" ... "); + s.AppendLine(); + + for (int i = (int)array.size - 5; i < len; i++) + { + PrettyPrint(s, array[i], flat); + if (i < len - 1) + { + s.Append(", "); + if (!flat) + s.AppendLine(); + } + } + } + + s.Append("]"); } - public static ParsedSliceArgs ParseSlices(Slice[] slices) { diff --git a/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs b/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs index 112e2d6c..0e024fd1 100644 --- a/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs +++ b/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Linq; using System.Text; +using Tensorflow; using Tensorflow.NumPy; namespace TensorFlowNET.UnitTest.NumPy @@ -88,5 +89,14 @@ namespace TensorFlowNET.UnitTest.NumPy AssetSequenceEqual(a.ToArray(), new int[] { 0, 1, 2, 0, 1, 2, 0, 1, 2 }); AssetSequenceEqual(b.ToArray(), new int[] { 0, 0, 0, 1, 1, 1, 2, 2, 2 }); } + + [TestMethod] + public void to_numpy_string() + { + var nd = np.arange(10 * 10 * 10 * 10).reshape((10, 10, 10, 10)); + var str = tensor_util.to_numpy_string(nd); + Assert.AreEqual("array([[[[0, 1, 2, ..., 7, 8, 9],", str.Substring(0, 33)); + Assert.AreEqual("[9990, 9991, 9992, ..., 9997, 9998, 9999]]]])", str.Substring(str.Length - 45)); + } } }