diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs index 68c2319c..1e8c21a9 100644 --- a/src/TensorFlowNET.Core/Binding.Util.cs +++ b/src/TensorFlowNET.Core/Binding.Util.cs @@ -270,12 +270,10 @@ namespace Tensorflow int i = 0; foreach(var val in values) { - if (i < start) - { - i++; + if (i++ < start) continue; - } - yield return (i, val); + + yield return (i - start, val); } } diff --git a/src/TensorFlowNET.Core/Data/DatasetV2.cs b/src/TensorFlowNET.Core/Data/DatasetV2.cs index 09bc38ff..071cbb29 100644 --- a/src/TensorFlowNET.Core/Data/DatasetV2.cs +++ b/src/TensorFlowNET.Core/Data/DatasetV2.cs @@ -78,9 +78,19 @@ namespace Tensorflow { var ownedIterator = new OwnedIterator(this); - Tensor[] results = ownedIterator.next(); - while (results != null) + bool stop = false; + Tensor[] results = null; + while (!stop) { + try + { + results = ownedIterator.next(); + } + catch (StopIteration) + { + stop = true; + } + yield return (results[0], results[1]); } } diff --git a/src/TensorFlowNET.Core/Data/OwnedIterator.cs b/src/TensorFlowNET.Core/Data/OwnedIterator.cs index bde214e8..cd9e8820 100644 --- a/src/TensorFlowNET.Core/Data/OwnedIterator.cs +++ b/src/TensorFlowNET.Core/Data/OwnedIterator.cs @@ -35,6 +35,15 @@ namespace Tensorflow } public Tensor[] next() - => ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes); + { + try + { + return ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes); + } + catch (OutOfRangeError ex) + { + throw new StopIteration(ex.Message); + } + } } } diff --git a/src/TensorFlowNET.Core/Exceptions/OutOfRangeError.cs b/src/TensorFlowNET.Core/Exceptions/OutOfRangeError.cs new file mode 100644 index 00000000..422ff059 --- /dev/null +++ b/src/TensorFlowNET.Core/Exceptions/OutOfRangeError.cs @@ -0,0 +1,17 @@ +using System; + +namespace Tensorflow +{ + public class OutOfRangeError : TensorflowException + { + public OutOfRangeError() : base() + { + + } + + public OutOfRangeError(string message) : base(message) + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Status/Status.cs b/src/TensorFlowNET.Core/Status/Status.cs index 95ab82d9..30051fb7 100644 --- a/src/TensorFlowNET.Core/Status/Status.cs +++ b/src/TensorFlowNET.Core/Status/Status.cs @@ -78,9 +78,18 @@ namespace Tensorflow if (Code != TF_Code.TF_OK) { var message = Message; - Console.WriteLine(message); + if (throwException) - throw new TensorflowException(message); + { + switch (Code) + { + case TF_Code.TF_OUT_OF_RANGE: + throw new OutOfRangeError(message); + default: + Console.WriteLine(message); + throw new TensorflowException(message); + } + } } } diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs index bcf895db..5203d43e 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; using System.Text; using Tensorflow.Eager; +using static Tensorflow.Binding; namespace Tensorflow { @@ -21,5 +22,8 @@ namespace Tensorflow public static implicit operator Tensor(IntPtr handle) => new Tensor(handle); + + public static implicit operator Tensor(NDArray nd) + => tf.convert_to_tensor(nd); } }