From 8ab1fe99e786f1badf47e36089aec2c73ba8303a Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 31 Jul 2021 23:50:07 -0500 Subject: [PATCH] ndarray SetData --- src/TensorFlowNET.Core/NumPy/NDArray.Index.cs | 61 ++++++++----------- src/TensorFlowNET.Core/NumPy/ShapeHelper.cs | 17 ++++++ src/TensorFlowNET.Core/NumPy/SliceHelper.cs | 5 +- .../NumPy/Array.Indexing.Test.cs | 15 +++++ 4 files changed, 62 insertions(+), 36 deletions(-) diff --git a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs index 0b751c39..0073b4f5 100644 --- a/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs +++ b/src/TensorFlowNET.Core/NumPy/NDArray.Index.cs @@ -170,62 +170,53 @@ namespace Tensorflow.NumPy } void SetData(IEnumerable slices, NDArray array) - => SetData(slices, array, -1, slices.Select(x => 0).ToArray()); + => SetData(array, data, slices.ToArray(), new int[shape.ndim].ToArray(), -1); - void SetData(IEnumerable slices, NDArray array, int currentNDim, int[] indices) + unsafe void SetData(NDArray src, IntPtr dst, Slice[] slices, int[] indices, int currentNDim) { - if (dtype != array.dtype) - array = array.astype(dtype); + if (dtype != src.dtype) + src = src.astype(dtype); // throw new ArrayTypeMismatchException($"Required dtype {dtype} but {array.dtype} is assigned."); if (!slices.Any()) return; - var newshape = ShapeHelper.GetShape(shape, slices.ToArray()); - if(newshape.Equals(array.shape)) + // first iteration + if(currentNDim == -1) { - var offset = ShapeHelper.GetOffset(shape, slices.First().Start ?? 0); - unsafe + slices = SliceHelper.AlignWithShape(shape, slices); + if (!shape.Equals(src.shape)) { - var dst = (byte*)data + (ulong)offset * dtypesize; - System.Buffer.MemoryCopy(array.data.ToPointer(), dst, array.bytesize, array.bytesize); + var newShape = ShapeHelper.AlignWithShape(shape, src.shape); + src = src.reshape(newShape); } - return; } - - var slice = slices.First(); - - if (slices.Count() == 1) + // last dimension + if (currentNDim == ndim - 1) { - - if (slice.Step != 1) - throw new NotImplementedException("slice.step should == 1"); - - if (slice.Start < 0) - throw new NotImplementedException("slice.start should > -1"); - - indices[indices.Length - 1] = slice.Start ?? 0; - var offset = (ulong)ShapeHelper.GetOffset(shape, indices); - var bytesize = array.bytesize; - unsafe - { - var dst = (byte*)data + offset * dtypesize; - System.Buffer.MemoryCopy(array.data.ToPointer(), dst, bytesize, bytesize); - } - + System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize); return; } currentNDim++; - if (slice.Stop == null) - slice.Stop = (int)dims[currentNDim]; + var slice = slices[currentNDim]; + + var start = slice.Start ?? 0; + var stop = slice.Stop ?? (int)dims[currentNDim]; + var step = slice.Step; - for (var i = slice.Start ?? 0; i < slice.Stop; i++) + for (var i = start; i < stop; i += step) { indices[currentNDim] = i; - SetData(slices.Skip(1), array, currentNDim, indices); + var offset = (int)ShapeHelper.GetOffset(shape, indices); + dst = data + offset * (int)dtypesize; + var srcIndex = (i - start) / step; + SetData(src[srcIndex], dst, slices, indices, currentNDim); } + + // reset indices + indices[currentNDim] = 0; } } } diff --git a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs index dec43e83..4f4db76d 100644 --- a/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs +++ b/src/TensorFlowNET.Core/NumPy/ShapeHelper.cs @@ -69,6 +69,23 @@ namespace Tensorflow.NumPy return new Shape(return_dims.ToArray()); } + public static Shape AlignWithShape(Shape shape, Shape preShape) + { + if (shape.ndim == preShape.ndim) + return preShape; + + var newShape = shape.dims.Select(x => 1L).ToArray(); + if (preShape.IsScalar) + return new Shape(newShape); + + for (int i = 0; i < preShape.ndim; i++) + { + newShape[i + shape.ndim - preShape.ndim] = preShape[i]; + } + + return new Shape(newShape); + } + public static bool Equals(Shape shape, object target) { switch (target) diff --git a/src/TensorFlowNET.Core/NumPy/SliceHelper.cs b/src/TensorFlowNET.Core/NumPy/SliceHelper.cs index 1090ce27..d0739eff 100644 --- a/src/TensorFlowNET.Core/NumPy/SliceHelper.cs +++ b/src/TensorFlowNET.Core/NumPy/SliceHelper.cs @@ -9,8 +9,11 @@ namespace Tensorflow.NumPy { public static Slice[] AlignWithShape(Shape shape, Slice[] slices) { - // align slices var ndim = shape.ndim; + if (ndim == slices.Length) + return slices; + + // align slices var new_slices = new List(); var slice_index = 0; diff --git a/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs b/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs index 7c1a6d15..573c2fd2 100644 --- a/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs +++ b/test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs @@ -116,5 +116,20 @@ namespace TensorFlowNET.UnitTest.NumPy i++; } } + + [TestMethod] + public void slice_step() + { + var array = np.arange(32).reshape((4, 8)); + var s1 = array[Slice.All, new Slice(2, 5, 2)] + 1; + Assert.AreEqual(s1.shape, (4, 2)); + var expected = new[] { 3, 5, 11, 13, 19, 21, 27, 29 }; + Assert.IsTrue(Enumerable.SequenceEqual(expected, s1.ToArray())); + array[Slice.All, new Slice(2, 5, 2)] = s1; + Assert.AreEqual(array[0], new[] { 0, 1, 3, 3, 5, 5, 6, 7 }); + Assert.AreEqual(array[1], new[] { 8, 9, 11, 11, 13, 13, 14, 15 }); + Assert.AreEqual(array[2], new[] { 16, 17, 19, 19, 21, 21, 22, 23 }); + Assert.AreEqual(array[3], new[] { 24, 25, 27, 27, 29, 29, 30, 31 }); + } } }