From 6a295b68fc0a56423aa05f55c5cdb3d3668d6193 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Sun, 5 Mar 2023 01:12:41 +0800 Subject: [PATCH] Add more explicit conversion for Tensors. --- src/TensorFlowNET.Core/Tensors/Tensors.cs | 101 ++++++++++++++++++ .../Dataset/DatasetTest.cs | 18 ++-- 2 files changed, 110 insertions(+), 9 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/Tensors.cs b/src/TensorFlowNET.Core/Tensors/Tensors.cs index ecd844d1..7fa4dd44 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensors.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensors.cs @@ -65,6 +65,93 @@ namespace Tensorflow IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + public NDArray numpy() + { + EnsureSingleTensor(this, "nnumpy"); + return this[0].numpy(); + } + + public T[] ToArray() where T: unmanaged + { + EnsureSingleTensor(this, $"ToArray<{typeof(T)}>"); + return this[0].ToArray(); + } + + #region Explicit Conversions + public unsafe static explicit operator bool(Tensors tensor) + { + EnsureSingleTensor(tensor, "explicit conversion to bool"); + return (bool)tensor[0]; + } + + public unsafe static explicit operator sbyte(Tensors tensor) + { + EnsureSingleTensor(tensor, "explicit conversion to sbyte"); + return (sbyte)tensor[0]; + } + + public unsafe static explicit operator byte(Tensors tensor) + { + EnsureSingleTensor(tensor, "explicit conversion to byte"); + return (byte)tensor[0]; + } + + public unsafe static explicit operator ushort(Tensors tensor) + { + EnsureSingleTensor(tensor, "explicit conversion to ushort"); + return (ushort)tensor[0]; + } + + public unsafe static explicit operator short(Tensors tensor) + { + EnsureSingleTensor(tensor, "explicit conversion to short"); + return (short)tensor[0]; + } + + public unsafe static explicit operator int(Tensors tensor) + { + EnsureSingleTensor(tensor, "explicit conversion to int"); + return (int)tensor[0]; + } + + public unsafe static explicit operator uint(Tensors tensor) + { + EnsureSingleTensor(tensor, "explicit conversion to uint"); + return (uint)tensor[0]; + } + + public unsafe static explicit operator long(Tensors tensor) + { + EnsureSingleTensor(tensor, "explicit conversion to long"); + return (long)tensor[0]; + } + + public unsafe static explicit operator ulong(Tensors tensor) + { + EnsureSingleTensor(tensor, "explicit conversion to ulong"); + return (ulong)tensor[0]; + } + + public unsafe static explicit operator float(Tensors tensor) + { + EnsureSingleTensor(tensor, "explicit conversion to byte"); + return (byte)tensor[0]; + } + + public unsafe static explicit operator double(Tensors tensor) + { + EnsureSingleTensor(tensor, "explicit conversion to double"); + return (double)tensor[0]; + } + + public unsafe static explicit operator string(Tensors tensor) + { + EnsureSingleTensor(tensor, "explicit conversion to string"); + return (string)tensor[0]; + } + #endregion + + #region Implicit Conversions public static implicit operator Tensors(Tensor tensor) => new Tensors(tensor); @@ -87,12 +174,26 @@ namespace Tensorflow public static implicit operator Tensor[](Tensors tensors) => tensors.items.ToArray(); + #endregion + public void Deconstruct(out Tensor a, out Tensor b) { a = items[0]; b = items[1]; } + private static void EnsureSingleTensor(Tensors tensors, string methodnName) + { + if(tensors.Length == 0) + { + throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor."); + } + else if(tensors.Length > 1) + { + throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor."); + } + } + public override string ToString() => items.Count() == 1 ? items.First().ToString() diff --git a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs index 01f35a41..8317346e 100644 --- a/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs +++ b/test/TensorFlowNET.UnitTest/Dataset/DatasetTest.cs @@ -20,7 +20,7 @@ namespace TensorFlowNET.UnitTest.Dataset Assert.AreEqual(iStep, step); iStep++; - Assert.AreEqual(value, (long)item.Item1[0]); + Assert.AreEqual(value, (long)item.Item1); value++; } } @@ -39,7 +39,7 @@ namespace TensorFlowNET.UnitTest.Dataset Assert.AreEqual(iStep, step); iStep++; - Assert.AreEqual(value, (long)item.Item1[0]); + Assert.AreEqual(value, (long)item.Item1); value += 2; } } @@ -54,7 +54,7 @@ namespace TensorFlowNET.UnitTest.Dataset int n = 0; foreach (var (item_x, item_y) in dataset) { - print($"x:{item_x[0].numpy()},y:{item_y[0].numpy()}"); + print($"x:{item_x.numpy()},y:{item_y.numpy()}"); n += 1; } Assert.AreEqual(5, n); @@ -69,7 +69,7 @@ namespace TensorFlowNET.UnitTest.Dataset int n = 0; foreach (var x in dataset) { - Assert.IsTrue(X.SequenceEqual(x.Item1[0].ToArray())); + Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray())); n += 1; } Assert.AreEqual(1, n); @@ -85,7 +85,7 @@ namespace TensorFlowNET.UnitTest.Dataset foreach (var item in dataset2) { - Assert.AreEqual(value, (long)item.Item1[0]); + Assert.AreEqual(value, (long)item.Item1); value += 3; } @@ -93,7 +93,7 @@ namespace TensorFlowNET.UnitTest.Dataset var dataset3 = dataset1.shard(num_shards: 3, index: 1); foreach (var item in dataset3) { - Assert.AreEqual(value, (long)item.Item1[0]); + Assert.AreEqual(value, (long)item.Item1); value += 3; } } @@ -108,7 +108,7 @@ namespace TensorFlowNET.UnitTest.Dataset foreach (var item in dataset) { - Assert.AreEqual(value, (long)item.Item1[0]); + Assert.AreEqual(value, (long)item.Item1); value++; } } @@ -123,7 +123,7 @@ namespace TensorFlowNET.UnitTest.Dataset foreach (var item in dataset) { - Assert.AreEqual(value + 10, (long)item.Item1[0]); + Assert.AreEqual(value + 10, (long)item.Item1); value++; } } @@ -138,7 +138,7 @@ namespace TensorFlowNET.UnitTest.Dataset foreach (var item in dataset) { - Assert.AreEqual(value, (long)item.Item1[0]); + Assert.AreEqual(value, (long)item.Item1); value++; } }