| @@ -65,6 +65,93 @@ namespace Tensorflow | |||||
| IEnumerator IEnumerable.GetEnumerator() | IEnumerator IEnumerable.GetEnumerator() | ||||
| => GetEnumerator(); | => GetEnumerator(); | ||||
| public NDArray numpy() | |||||
| { | |||||
| EnsureSingleTensor(this, "nnumpy"); | |||||
| return this[0].numpy(); | |||||
| } | |||||
| public T[] ToArray<T>() where T: unmanaged | |||||
| { | |||||
| EnsureSingleTensor(this, $"ToArray<{typeof(T)}>"); | |||||
| return this[0].ToArray<T>(); | |||||
| } | |||||
| #region Explicit Conversions | |||||
| public unsafe static explicit operator bool(Tensors tensor) | |||||
| { | |||||
| EnsureSingleTensor(tensor, "explicit conversion to bool"); | |||||
| return (bool)tensor[0]; | |||||
| } | |||||
| public unsafe static explicit operator sbyte(Tensors tensor) | |||||
| { | |||||
| EnsureSingleTensor(tensor, "explicit conversion to sbyte"); | |||||
| return (sbyte)tensor[0]; | |||||
| } | |||||
| public unsafe static explicit operator byte(Tensors tensor) | |||||
| { | |||||
| EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||||
| return (byte)tensor[0]; | |||||
| } | |||||
| public unsafe static explicit operator ushort(Tensors tensor) | |||||
| { | |||||
| EnsureSingleTensor(tensor, "explicit conversion to ushort"); | |||||
| return (ushort)tensor[0]; | |||||
| } | |||||
| public unsafe static explicit operator short(Tensors tensor) | |||||
| { | |||||
| EnsureSingleTensor(tensor, "explicit conversion to short"); | |||||
| return (short)tensor[0]; | |||||
| } | |||||
| public unsafe static explicit operator int(Tensors tensor) | |||||
| { | |||||
| EnsureSingleTensor(tensor, "explicit conversion to int"); | |||||
| return (int)tensor[0]; | |||||
| } | |||||
| public unsafe static explicit operator uint(Tensors tensor) | |||||
| { | |||||
| EnsureSingleTensor(tensor, "explicit conversion to uint"); | |||||
| return (uint)tensor[0]; | |||||
| } | |||||
| public unsafe static explicit operator long(Tensors tensor) | |||||
| { | |||||
| EnsureSingleTensor(tensor, "explicit conversion to long"); | |||||
| return (long)tensor[0]; | |||||
| } | |||||
| public unsafe static explicit operator ulong(Tensors tensor) | |||||
| { | |||||
| EnsureSingleTensor(tensor, "explicit conversion to ulong"); | |||||
| return (ulong)tensor[0]; | |||||
| } | |||||
| public unsafe static explicit operator float(Tensors tensor) | |||||
| { | |||||
| EnsureSingleTensor(tensor, "explicit conversion to byte"); | |||||
| return (byte)tensor[0]; | |||||
| } | |||||
| public unsafe static explicit operator double(Tensors tensor) | |||||
| { | |||||
| EnsureSingleTensor(tensor, "explicit conversion to double"); | |||||
| return (double)tensor[0]; | |||||
| } | |||||
| public unsafe static explicit operator string(Tensors tensor) | |||||
| { | |||||
| EnsureSingleTensor(tensor, "explicit conversion to string"); | |||||
| return (string)tensor[0]; | |||||
| } | |||||
| #endregion | |||||
| #region Implicit Conversions | |||||
| public static implicit operator Tensors(Tensor tensor) | public static implicit operator Tensors(Tensor tensor) | ||||
| => new Tensors(tensor); | => new Tensors(tensor); | ||||
| @@ -87,12 +174,26 @@ namespace Tensorflow | |||||
| public static implicit operator Tensor[](Tensors tensors) | public static implicit operator Tensor[](Tensors tensors) | ||||
| => tensors.items.ToArray(); | => tensors.items.ToArray(); | ||||
| #endregion | |||||
| public void Deconstruct(out Tensor a, out Tensor b) | public void Deconstruct(out Tensor a, out Tensor b) | ||||
| { | { | ||||
| a = items[0]; | a = items[0]; | ||||
| b = items[1]; | b = items[1]; | ||||
| } | } | ||||
| private static void EnsureSingleTensor(Tensors tensors, string methodnName) | |||||
| { | |||||
| if(tensors.Length == 0) | |||||
| { | |||||
| throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains no Tensor."); | |||||
| } | |||||
| else if(tensors.Length > 1) | |||||
| { | |||||
| throw new ValueError($"Method `{methodnName}` of `Tensors` cannot be used when `Tensors` contains more than one Tensor."); | |||||
| } | |||||
| } | |||||
| public override string ToString() | public override string ToString() | ||||
| => items.Count() == 1 | => items.Count() == 1 | ||||
| ? items.First().ToString() | ? items.First().ToString() | ||||
| @@ -20,7 +20,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| Assert.AreEqual(iStep, step); | Assert.AreEqual(iStep, step); | ||||
| iStep++; | iStep++; | ||||
| Assert.AreEqual(value, (long)item.Item1[0]); | |||||
| Assert.AreEqual(value, (long)item.Item1); | |||||
| value++; | value++; | ||||
| } | } | ||||
| } | } | ||||
| @@ -39,7 +39,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| Assert.AreEqual(iStep, step); | Assert.AreEqual(iStep, step); | ||||
| iStep++; | iStep++; | ||||
| Assert.AreEqual(value, (long)item.Item1[0]); | |||||
| Assert.AreEqual(value, (long)item.Item1); | |||||
| value += 2; | value += 2; | ||||
| } | } | ||||
| } | } | ||||
| @@ -54,7 +54,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| int n = 0; | int n = 0; | ||||
| foreach (var (item_x, item_y) in dataset) | foreach (var (item_x, item_y) in dataset) | ||||
| { | { | ||||
| print($"x:{item_x[0].numpy()},y:{item_y[0].numpy()}"); | |||||
| print($"x:{item_x.numpy()},y:{item_y.numpy()}"); | |||||
| n += 1; | n += 1; | ||||
| } | } | ||||
| Assert.AreEqual(5, n); | Assert.AreEqual(5, n); | ||||
| @@ -69,7 +69,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| int n = 0; | int n = 0; | ||||
| foreach (var x in dataset) | foreach (var x in dataset) | ||||
| { | { | ||||
| Assert.IsTrue(X.SequenceEqual(x.Item1[0].ToArray<int>())); | |||||
| Assert.IsTrue(X.SequenceEqual(x.Item1.ToArray<int>())); | |||||
| n += 1; | n += 1; | ||||
| } | } | ||||
| Assert.AreEqual(1, n); | Assert.AreEqual(1, n); | ||||
| @@ -85,7 +85,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| foreach (var item in dataset2) | foreach (var item in dataset2) | ||||
| { | { | ||||
| Assert.AreEqual(value, (long)item.Item1[0]); | |||||
| Assert.AreEqual(value, (long)item.Item1); | |||||
| value += 3; | value += 3; | ||||
| } | } | ||||
| @@ -93,7 +93,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| var dataset3 = dataset1.shard(num_shards: 3, index: 1); | var dataset3 = dataset1.shard(num_shards: 3, index: 1); | ||||
| foreach (var item in dataset3) | foreach (var item in dataset3) | ||||
| { | { | ||||
| Assert.AreEqual(value, (long)item.Item1[0]); | |||||
| Assert.AreEqual(value, (long)item.Item1); | |||||
| value += 3; | value += 3; | ||||
| } | } | ||||
| } | } | ||||
| @@ -108,7 +108,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| foreach (var item in dataset) | foreach (var item in dataset) | ||||
| { | { | ||||
| Assert.AreEqual(value, (long)item.Item1[0]); | |||||
| Assert.AreEqual(value, (long)item.Item1); | |||||
| value++; | value++; | ||||
| } | } | ||||
| } | } | ||||
| @@ -123,7 +123,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| foreach (var item in dataset) | foreach (var item in dataset) | ||||
| { | { | ||||
| Assert.AreEqual(value + 10, (long)item.Item1[0]); | |||||
| Assert.AreEqual(value + 10, (long)item.Item1); | |||||
| value++; | value++; | ||||
| } | } | ||||
| } | } | ||||
| @@ -138,7 +138,7 @@ namespace TensorFlowNET.UnitTest.Dataset | |||||
| foreach (var item in dataset) | foreach (var item in dataset) | ||||
| { | { | ||||
| Assert.AreEqual(value, (long)item.Item1[0]); | |||||
| Assert.AreEqual(value, (long)item.Item1); | |||||
| value++; | value++; | ||||
| } | } | ||||
| } | } | ||||