Browse Source

ndarray SetData

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
8ab1fe99e7
4 changed files with 62 additions and 36 deletions
  1. +26
    -35
      src/TensorFlowNET.Core/NumPy/NDArray.Index.cs
  2. +17
    -0
      src/TensorFlowNET.Core/NumPy/ShapeHelper.cs
  3. +4
    -1
      src/TensorFlowNET.Core/NumPy/SliceHelper.cs
  4. +15
    -0
      test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs

+ 26
- 35
src/TensorFlowNET.Core/NumPy/NDArray.Index.cs View File

@@ -170,62 +170,53 @@ namespace Tensorflow.NumPy
}

void SetData(IEnumerable<Slice> slices, NDArray array)
=> SetData(slices, array, -1, slices.Select(x => 0).ToArray());
=> SetData(array, data, slices.ToArray(), new int[shape.ndim].ToArray(), -1);

void SetData(IEnumerable<Slice> slices, NDArray array, int currentNDim, int[] indices)
unsafe void SetData(NDArray src, IntPtr dst, Slice[] slices, int[] indices, int currentNDim)
{
if (dtype != array.dtype)
array = array.astype(dtype);
if (dtype != src.dtype)
src = src.astype(dtype);
// throw new ArrayTypeMismatchException($"Required dtype {dtype} but {array.dtype} is assigned.");

if (!slices.Any())
return;

var newshape = ShapeHelper.GetShape(shape, slices.ToArray());
if(newshape.Equals(array.shape))
// first iteration
if(currentNDim == -1)
{
var offset = ShapeHelper.GetOffset(shape, slices.First().Start ?? 0);
unsafe
slices = SliceHelper.AlignWithShape(shape, slices);
if (!shape.Equals(src.shape))
{
var dst = (byte*)data + (ulong)offset * dtypesize;
System.Buffer.MemoryCopy(array.data.ToPointer(), dst, array.bytesize, array.bytesize);
var newShape = ShapeHelper.AlignWithShape(shape, src.shape);
src = src.reshape(newShape);
}
return;
}


var slice = slices.First();

if (slices.Count() == 1)
// last dimension
if (currentNDim == ndim - 1)
{

if (slice.Step != 1)
throw new NotImplementedException("slice.step should == 1");

if (slice.Start < 0)
throw new NotImplementedException("slice.start should > -1");

indices[indices.Length - 1] = slice.Start ?? 0;
var offset = (ulong)ShapeHelper.GetOffset(shape, indices);
var bytesize = array.bytesize;
unsafe
{
var dst = (byte*)data + offset * dtypesize;
System.Buffer.MemoryCopy(array.data.ToPointer(), dst, bytesize, bytesize);
}

System.Buffer.MemoryCopy(src.data.ToPointer(), dst.ToPointer(), src.bytesize, src.bytesize);
return;
}

currentNDim++;
if (slice.Stop == null)
slice.Stop = (int)dims[currentNDim];
var slice = slices[currentNDim];
var start = slice.Start ?? 0;
var stop = slice.Stop ?? (int)dims[currentNDim];
var step = slice.Step;

for (var i = slice.Start ?? 0; i < slice.Stop; i++)
for (var i = start; i < stop; i += step)
{
indices[currentNDim] = i;
SetData(slices.Skip(1), array, currentNDim, indices);
var offset = (int)ShapeHelper.GetOffset(shape, indices);
dst = data + offset * (int)dtypesize;
var srcIndex = (i - start) / step;
SetData(src[srcIndex], dst, slices, indices, currentNDim);
}

// reset indices
indices[currentNDim] = 0;
}
}
}

+ 17
- 0
src/TensorFlowNET.Core/NumPy/ShapeHelper.cs View File

@@ -69,6 +69,23 @@ namespace Tensorflow.NumPy
return new Shape(return_dims.ToArray());
}

public static Shape AlignWithShape(Shape shape, Shape preShape)
{
if (shape.ndim == preShape.ndim)
return preShape;

var newShape = shape.dims.Select(x => 1L).ToArray();
if (preShape.IsScalar)
return new Shape(newShape);

for (int i = 0; i < preShape.ndim; i++)
{
newShape[i + shape.ndim - preShape.ndim] = preShape[i];
}

return new Shape(newShape);
}

public static bool Equals(Shape shape, object target)
{
switch (target)


+ 4
- 1
src/TensorFlowNET.Core/NumPy/SliceHelper.cs View File

@@ -9,8 +9,11 @@ namespace Tensorflow.NumPy
{
public static Slice[] AlignWithShape(Shape shape, Slice[] slices)
{
// align slices
var ndim = shape.ndim;
if (ndim == slices.Length)
return slices;

// align slices
var new_slices = new List<Slice>();
var slice_index = 0;



+ 15
- 0
test/TensorFlowNET.UnitTest/NumPy/Array.Indexing.Test.cs View File

@@ -116,5 +116,20 @@ namespace TensorFlowNET.UnitTest.NumPy
i++;
}
}

[TestMethod]
public void slice_step()
{
var array = np.arange(32).reshape((4, 8));
var s1 = array[Slice.All, new Slice(2, 5, 2)] + 1;
Assert.AreEqual(s1.shape, (4, 2));
var expected = new[] { 3, 5, 11, 13, 19, 21, 27, 29 };
Assert.IsTrue(Enumerable.SequenceEqual(expected, s1.ToArray<int>()));
array[Slice.All, new Slice(2, 5, 2)] = s1;
Assert.AreEqual(array[0], new[] { 0, 1, 3, 3, 5, 5, 6, 7 });
Assert.AreEqual(array[1], new[] { 8, 9, 11, 11, 13, 13, 14, 15 });
Assert.AreEqual(array[2], new[] { 16, 17, 19, 19, 21, 21, 22, 23 });
Assert.AreEqual(array[3], new[] { 24, 25, 27, 27, 29, 29, 30, 31 });
}
}
}

Loading…
Cancel
Save