Browse Source

fix iterator_get_next for dataset.

tags/v0.20
Oceania2018 5 years ago
parent
commit
34efa15d4d
6 changed files with 57 additions and 10 deletions
  1. +3
    -5
      src/TensorFlowNET.Core/Binding.Util.cs
  2. +12
    -2
      src/TensorFlowNET.Core/Data/DatasetV2.cs
  3. +10
    -1
      src/TensorFlowNET.Core/Data/OwnedIterator.cs
  4. +17
    -0
      src/TensorFlowNET.Core/Exceptions/OutOfRangeError.cs
  5. +11
    -2
      src/TensorFlowNET.Core/Status/Status.cs
  6. +4
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs

+ 3
- 5
src/TensorFlowNET.Core/Binding.Util.cs View File

@@ -270,12 +270,10 @@ namespace Tensorflow
int i = 0;
foreach(var val in values)
{
if (i < start)
{
i++;
if (i++ < start)
continue;
}
yield return (i, val);
yield return (i - start, val);
}
}



+ 12
- 2
src/TensorFlowNET.Core/Data/DatasetV2.cs View File

@@ -78,9 +78,19 @@ namespace Tensorflow
{
var ownedIterator = new OwnedIterator(this);

Tensor[] results = ownedIterator.next();
while (results != null)
bool stop = false;
Tensor[] results = null;
while (!stop)
{
try
{
results = ownedIterator.next();
}
catch (StopIteration)
{
stop = true;
}

yield return (results[0], results[1]);
}
}


+ 10
- 1
src/TensorFlowNET.Core/Data/OwnedIterator.cs View File

@@ -35,6 +35,15 @@ namespace Tensorflow
}

public Tensor[] next()
=> ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes);
{
try
{
return ops.iterator_get_next(_iterator_resource, _dataset.output_types, _dataset.output_shapes);
}
catch (OutOfRangeError ex)
{
throw new StopIteration(ex.Message);
}
}
}
}

+ 17
- 0
src/TensorFlowNET.Core/Exceptions/OutOfRangeError.cs View File

@@ -0,0 +1,17 @@
using System;

namespace Tensorflow
{
public class OutOfRangeError : TensorflowException
{
public OutOfRangeError() : base()
{

}

public OutOfRangeError(string message) : base(message)
{

}
}
}

+ 11
- 2
src/TensorFlowNET.Core/Status/Status.cs View File

@@ -78,9 +78,18 @@ namespace Tensorflow
if (Code != TF_Code.TF_OK)
{
var message = Message;
Console.WriteLine(message);
if (throwException)
throw new TensorflowException(message);
{
switch (Code)
{
case TF_Code.TF_OUT_OF_RANGE:
throw new OutOfRangeError(message);
default:
Console.WriteLine(message);
throw new TensorflowException(message);
}
}
}
}



+ 4
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Implicit.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Eager;
using static Tensorflow.Binding;

namespace Tensorflow
{
@@ -21,5 +22,8 @@ namespace Tensorflow

public static implicit operator Tensor(IntPtr handle)
=> new Tensor(handle);

public static implicit operator Tensor(NDArray nd)
=> tf.convert_to_tensor(nd);
}
}

Loading…
Cancel
Save