diff --git a/src/TensorFlowNET.Console/Tensorflow.Console.csproj b/src/TensorFlowNET.Console/Tensorflow.Console.csproj
index 43d7a740..e6b2ea1d 100644
--- a/src/TensorFlowNET.Console/Tensorflow.Console.csproj
+++ b/src/TensorFlowNET.Console/Tensorflow.Console.csproj
@@ -14,8 +14,12 @@
x64
+
+ DEBUG;TRACE
+
+
-
+
diff --git a/src/TensorFlowNET.Core/Binding.Util.cs b/src/TensorFlowNET.Core/Binding.Util.cs
index edc8edcc..b7aca46b 100644
--- a/src/TensorFlowNET.Core/Binding.Util.cs
+++ b/src/TensorFlowNET.Core/Binding.Util.cs
@@ -271,14 +271,18 @@ namespace Tensorflow
}
}
- public static IEnumerable<(T, T)> zip(NDArray t1, NDArray t2)
+ public static IEnumerable<(T, T)> zip(NDArray t1, NDArray t2, Axis axis = null)
where T : unmanaged
{
- /*var a = t1.AsIterator();
- var b = t2.AsIterator();
- while (a.HasNext() && b.HasNext())
- yield return (a.MoveNext(), b.MoveNext());*/
- throw new NotImplementedException("");
+ if (axis == null)
+ {
+ var a = t1.Data();
+ var b = t2.Data();
+ for (int i = 0; i < a.Length; i++)
+ yield return (a[i], b[i]);
+ }
+ else
+ throw new NotImplementedException("");
}
public static IEnumerable<(T1, T2)> zip(IList t1, IList t2)
diff --git a/src/TensorFlowNET.Core/Data/MnistModelLoader.cs b/src/TensorFlowNET.Core/Data/MnistModelLoader.cs
index f99c1e5d..e7251217 100644
--- a/src/TensorFlowNET.Core/Data/MnistModelLoader.cs
+++ b/src/TensorFlowNET.Core/Data/MnistModelLoader.cs
@@ -166,7 +166,7 @@ namespace Tensorflow
for (int row = 0; row < num_labels; row++)
{
var col = labels[row];
- labels_one_hot.SetData(1.0, row, col);
+ labels_one_hot[row, col] = 1.0;
}
return labels_one_hot;
diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs
index 86914512..d9ad9ae6 100644
--- a/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs
+++ b/src/TensorFlowNET.Core/NumPy/NDArray.Equal.cs
@@ -25,6 +25,10 @@ namespace Tensorflow.NumPy
{
if (x.ndim != y.ndim)
return false;
+ else if (x.size != y.size)
+ return false;
+ else if (x.dtype != y.dtype)
+ return false;
return Enumerable.SequenceEqual(x.ToByteArray(), y.ToByteArray());
}
diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
index e3399bc0..515c3dcb 100644
--- a/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
+++ b/src/TensorFlowNET.Core/NumPy/NDArray.Implicit.cs
@@ -8,7 +8,7 @@ namespace Tensorflow.NumPy
{
public void Deconstruct(out byte blue, out byte green, out byte red)
{
- var data = Data();
+ var data = ToArray();
blue = data[0];
green = data[1];
red = data[2];
@@ -17,23 +17,23 @@ namespace Tensorflow.NumPy
public static implicit operator NDArray(Array array)
=> new NDArray(array);
- public static implicit operator bool(NDArray nd)
- => nd._tensor.ToArray()[0];
+ public unsafe static implicit operator bool(NDArray nd)
+ => *(bool*)nd.data;
- public static implicit operator byte(NDArray nd)
- => nd._tensor.ToArray()[0];
+ public unsafe static implicit operator byte(NDArray nd)
+ => *(byte*)nd.data;
- public static implicit operator byte[](NDArray nd)
- => nd.ToByteArray();
+ public unsafe static implicit operator int(NDArray nd)
+ => *(int*)nd.data;
- public static implicit operator int(NDArray nd)
- => nd._tensor.ToArray()[0];
+ public unsafe static implicit operator long(NDArray nd)
+ => *(long*)nd.data;
- public static implicit operator float(NDArray nd)
- => nd._tensor.ToArray()[0];
+ public unsafe static implicit operator float(NDArray nd)
+ => *(float*)nd.data;
- public static implicit operator double(NDArray nd)
- => nd._tensor.ToArray()[0];
+ public unsafe static implicit operator double(NDArray nd)
+ => *(double*)nd.data;
public static implicit operator NDArray(bool value)
=> new NDArray(value);
diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
index 1cfcdb38..be340789 100644
--- a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
+++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
@@ -8,76 +8,78 @@ namespace Tensorflow.NumPy
{
public partial class NDArray
{
- public NDArray this[int index]
+ public NDArray this[params int[] index]
{
- get
+ get => _tensor[index.Select(x => new Slice
{
- return _tensor[index];
- }
+ Start = x,
+ Stop = x + 1,
+ IsIndex = true
+ }).ToArray()];
- set
+ set => SetData(index.Select(x => new Slice
{
+ Start = x,
+ Stop = x + 1,
+ IsIndex = true
+ }), value);
+ }
- }
+ public NDArray this[params Slice[] slices]
+ {
+ get => _tensor[slices];
+ set => SetData(slices, value);
}
- public NDArray this[params int[] index]
+ public NDArray this[NDArray mask]
{
get
{
- return _tensor[index.Select(x => new Slice(x, x + 1)).ToArray()];
+ throw new NotImplementedException("");
}
set
{
- var offset = ShapeHelper.GetOffset(shape, index);
- unsafe
- {
- if (dtype == TF_DataType.TF_BOOL)
- *((bool*)data + offset) = value;
- else if (dtype == TF_DataType.TF_UINT8)
- *((byte*)data + offset) = value;
- else if (dtype == TF_DataType.TF_INT32)
- *((int*)data + offset) = value;
- else if (dtype == TF_DataType.TF_INT64)
- *((long*)data + offset) = value;
- else if (dtype == TF_DataType.TF_FLOAT)
- *((float*)data + offset) = value;
- else if (dtype == TF_DataType.TF_DOUBLE)
- *((double*)data + offset) = value;
- }
+ throw new NotImplementedException("");
}
}
- public NDArray this[params Slice[] slices]
+ void SetData(IEnumerable slices, NDArray array)
+ => SetData(slices, array, -1, slices.Select(x => 0).ToArray());
+
+ void SetData(IEnumerable slices, NDArray array, int currentNDim, int[] indices)
{
- get
- {
- return _tensor[slices];
- }
+ if (dtype != array.dtype)
+ throw new ArrayTypeMismatchException($"Required dtype {dtype} but {array.dtype} is assigned.");
- set
+ if (!slices.Any())
+ return;
+
+ var slice = slices.First();
+
+ if (slices.Count() == 1)
{
- var pos = _tensor[slices];
- var len = value.bytesize;
+
+ if (slice.Step != 1)
+ throw new NotImplementedException("");
+
+ indices[indices.Length - 1] = slice.Start ?? 0;
+ var offset = (ulong)ShapeHelper.GetOffset(shape, indices);
+ var bytesize = array.bytesize;
unsafe
{
- System.Buffer.MemoryCopy(value.data.ToPointer(), pos.TensorDataPointer.ToPointer(), len, len);
+ var dst = (byte*)data + offset * dtypesize;
+ System.Buffer.MemoryCopy(array.data.ToPointer(), dst, bytesize, bytesize);
}
- // _tensor[slices].assign(constant_op.constant(value));
- }
- }
- public NDArray this[NDArray mask]
- {
- get
- {
- throw new NotImplementedException("");
+ return;
}
- set
+ currentNDim++;
+ for (var i = slice.Start ?? 0; i < slice.Stop; i++)
{
-
+ indices[currentNDim] = i;
+ SetData(slices.Skip(1), array, currentNDim, indices);
}
}
}
diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs
new file mode 100644
index 00000000..960fae4b
--- /dev/null
+++ b/src/TensorFlowNET.Core/NumPy/NDArray.Operators.cs
@@ -0,0 +1,16 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.NumPy
+{
+ public partial class NDArray
+ {
+ public static NDArray operator +(NDArray lhs, NDArray rhs) => lhs.Tensor + rhs.Tensor;
+ public static NDArray operator -(NDArray lhs, NDArray rhs) => lhs.Tensor - rhs.Tensor;
+ public static NDArray operator *(NDArray lhs, NDArray rhs) => lhs.Tensor * rhs.Tensor;
+ public static NDArray operator /(NDArray lhs, NDArray rhs) => lhs.Tensor / rhs.Tensor;
+ }
+}
diff --git a/src/TensorFlowNET.Core/Numpy/NDArray.cs b/src/TensorFlowNET.Core/Numpy/NDArray.cs
index ff8b1d98..7293093e 100644
--- a/src/TensorFlowNET.Core/Numpy/NDArray.cs
+++ b/src/TensorFlowNET.Core/Numpy/NDArray.cs
@@ -25,6 +25,7 @@ namespace Tensorflow.NumPy
public partial class NDArray
{
Tensor _tensor;
+ public Tensor Tensor => _tensor;
public TF_DataType dtype => _tensor.dtype;
public ulong size => _tensor.size;
public ulong dtypesize => _tensor.dtypesize;
@@ -47,15 +48,12 @@ namespace Tensorflow.NumPy
public ValueType GetValue(params int[] indices)
=> throw new NotImplementedException("");
- public void SetData(object value, params int[] indices)
- => throw new NotImplementedException("");
-
public NDIterator AsIterator(bool autoreset = false) where T : unmanaged
=> throw new NotImplementedException("");
public bool HasNext() => throw new NotImplementedException("");
public T MoveNext() => throw new NotImplementedException("");
- public NDArray reshape(Shape newshape) => new NDArray(_tensor, newshape);
+ public NDArray reshape(Shape newshape) => new NDArray(tf.reshape(_tensor, newshape));
public NDArray astype(Type type) => new NDArray(math_ops.cast(_tensor, type.as_tf_dtype()));
public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(_tensor, dtype));
public NDArray ravel() => throw new NotImplementedException("");
diff --git a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
index fdc7fbbe..c9c44b8f 100644
--- a/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
+++ b/src/TensorFlowNET.Keras/Engine/DataAdapters/DataHandler.cs
@@ -14,13 +14,13 @@ namespace Tensorflow.Keras.Engine.DataAdapters
IDataAdapter _adapter;
public IDataAdapter DataAdapter => _adapter;
IDatasetV2 _dataset;
- int _inferred_steps;
- public int Inferredsteps => _inferred_steps;
- int _current_step;
- int _step_increment;
- public int StepIncrement => _step_increment;
+ long _inferred_steps;
+ public long Inferredsteps => _inferred_steps;
+ long _current_step;
+ long _step_increment;
+ public long StepIncrement => _step_increment;
bool _insufficient_data;
- int _steps_per_execution_value;
+ long _steps_per_execution_value;
int _initial_epoch => args.InitialEpoch;
int _epochs => args.Epochs;
IVariableV1 _steps_per_execution;
@@ -30,8 +30,8 @@ namespace Tensorflow.Keras.Engine.DataAdapters
this.args = args;
if (args.StepsPerExecution == null)
{
- _steps_per_execution = tf.Variable(1);
- _steps_per_execution_value = 1;
+ _steps_per_execution = tf.Variable(1L);
+ _steps_per_execution_value = 1L;
}
else
{
@@ -103,7 +103,7 @@ namespace Tensorflow.Keras.Engine.DataAdapters
// _adapter.on_epoch_end()
}
- public IEnumerable steps()
+ public IEnumerable steps()
{
_current_step = 0;
while (_current_step < _inferred_steps)
diff --git a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
index 4a630e0d..2bfb632a 100644
--- a/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/PreprocessingTests.cs
@@ -229,7 +229,7 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(9, oov_count);
}
- [TestMethod, Ignore("slice assign doesn't work")]
+ [TestMethod]
public void PadSequencesWithDefaults()
{
var tokenizer = keras.preprocessing.text.Tokenizer(oov_token: OOV);
@@ -241,12 +241,12 @@ namespace TensorFlowNET.Keras.UnitTest
Assert.AreEqual(4, padded.dims[0]);
Assert.AreEqual(22, padded.dims[1]);
- Assert.AreEqual(tokenizer.word_index["worst"], padded[0, 19]);
+ Assert.AreEqual(padded[0, 19], tokenizer.word_index["worst"]);
for (var i = 0; i < 8; i++)
- Assert.AreEqual(0, padded[0, i]);
- Assert.AreEqual(tokenizer.word_index["proud"], padded[1, 10]);
+ Assert.AreEqual(padded[0, i], 0);
+ Assert.AreEqual(padded[1, 10], tokenizer.word_index["proud"]);
for (var i = 0; i < 20; i++)
- Assert.AreNotEqual(0, padded[1, i]);
+ Assert.AreNotEqual(padded[1, i], 0);
}
[TestMethod, Ignore("slice assign doesn't work")]
diff --git a/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs b/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs
index 3694fd8e..1768eab5 100644
--- a/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs
+++ b/test/TensorFlowNET.UnitTest/Basics/SessionTest.cs
@@ -38,7 +38,7 @@ namespace TensorFlowNET.UnitTest
var c = tf.strings.substr(a, 4, 8);
using (var sess = tf.Session())
{
- var result = UTF8Encoding.UTF8.GetString((byte[])c.eval(sess));
+ var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray());
Console.WriteLine(result);
result.Should().Be("heythere");
}
@@ -55,7 +55,7 @@ namespace TensorFlowNET.UnitTest
var c = tf.strings.substr(a, 0, size - 5000);
using (var sess = tf.Session())
{
- var result = UTF8Encoding.UTF8.GetString((byte[])c.eval(sess));
+ var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray());
Console.WriteLine(result);
result.Should().HaveLength(size - 5000).And.ContainAll("a");
}
diff --git a/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs
new file mode 100644
index 00000000..8100be3f
--- /dev/null
+++ b/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs
@@ -0,0 +1,57 @@
+using Microsoft.VisualStudio.TestTools.UnitTesting;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using Tensorflow;
+using Tensorflow.NumPy;
+
+namespace TensorFlowNET.UnitTest.NumPy
+{
+ ///
+ /// https://numpy.org/doc/stable/user/basics.indexing.html
+ ///
+ [TestClass]
+ public class ArrayIndexingTest : EagerModeTestBase
+ {
+ [TestMethod]
+ public void int_params()
+ {
+ var x = np.arange(24).reshape((2, 3, 4));
+ x[1, 2, 3] = 1;
+ var y = x[1, 2, 3];
+ Assert.AreEqual(y.shape, Shape.Scalar);
+ Assert.AreEqual(y, 1);
+
+ x[0, 0] = new[] { 3, 1, 1, 2 };
+ y = x[0, 0];
+ Assert.AreEqual(y.shape, 4);
+ Assert.AreEqual(y, new[] { 3, 1, 1, 2 });
+
+ y = x[0];
+ Assert.AreEqual(y.shape, (3, 4));
+
+ var z = np.arange(12).reshape((3, 4));
+ x[1] = z;
+ Assert.AreEqual(x[1], z);
+ }
+
+ [TestMethod]
+ public void slice_params()
+ {
+ var x = np.arange(12).reshape((3, 4));
+ var y = x[new Slice(0, 1), new Slice(2)];
+ Assert.AreEqual(y.shape, (1, 2));
+ Assert.AreEqual(y, np.array(new[] { 2, 3 }).reshape((1, 2)));
+ }
+
+ [TestMethod]
+ public void slice_string_params()
+ {
+ var x = np.arange(12).reshape((3, 4));
+ var y = x[Slice.ParseSlices("0:1,2:")];
+ Assert.AreEqual(y.shape, (1, 2));
+ Assert.AreEqual(y, np.array(new[] { 2, 3 }).reshape((1, 2)));
+ }
+ }
+}
diff --git a/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs b/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs
index ddaa7c1d..f58ed389 100644
--- a/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs
+++ b/test/TensorFlowNET.UnitTest/Numpy/Array.Creation.Test.cs
@@ -5,13 +5,13 @@ using System.Linq;
using System.Text;
using Tensorflow.NumPy;
-namespace TensorFlowNET.UnitTest.Numpy
+namespace TensorFlowNET.UnitTest.NumPy
{
///
/// https://numpy.org/doc/stable/reference/routines.array-creation.html
///
[TestClass]
- public class NumpyArrayCreationTest : EagerModeTestBase
+ public class ArrayCreationTest : EagerModeTestBase
{
[TestMethod]
public void empty_zeros_ones_full()
diff --git a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
index ea4930fb..3c1dbcf3 100644
--- a/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
+++ b/test/TensorFlowNET.UnitTest/Numpy/Math.Test.cs
@@ -6,13 +6,13 @@ using System.Text;
using Tensorflow;
using Tensorflow.NumPy;
-namespace TensorFlowNET.UnitTest.Numpy
+namespace TensorFlowNET.UnitTest.NumPy
{
///
/// https://numpy.org/doc/stable/reference/generated/numpy.prod.html
///
[TestClass]
- public class NumpyMathTest : EagerModeTestBase
+ public class MathTest : EagerModeTestBase
{
[TestMethod]
public void prod()