Browse Source

support n-dim indexing. #368

tags/v0.12
Oceania2018 6 years ago
parent
commit
d44f044c86
2 changed files with 78 additions and 3 deletions
  1. +2
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  2. +76
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.Index.cs

+ 2
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -27,7 +27,8 @@ Docs: https://tensorflownet.readthedocs.io</Description>
6. Add Local Response Normalization. 6. Add Local Response Normalization.
7. Add tf.image related APIs. 7. Add tf.image related APIs.
8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor. 8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor.
9. MultiThread is safe.</PackageReleaseNotes>
9. MultiThread is safe.
10. Support n-dim indexing for tensor.</PackageReleaseNotes>
<LangVersion>7.3</LangVersion> <LangVersion>7.3</LangVersion>
<FileVersion>0.11.2.0</FileVersion> <FileVersion>0.11.2.0</FileVersion>
<PackageLicenseFile>LICENSE</PackageLicenseFile> <PackageLicenseFile>LICENSE</PackageLicenseFile>


+ 76
- 2
src/TensorFlowNET.Core/Tensors/Tensor.Index.cs View File

@@ -17,6 +17,7 @@
using NumSharp; using NumSharp;
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -24,11 +25,84 @@ namespace Tensorflow
{ {
public partial class Tensor public partial class Tensor
{ {
public Tensor this[int idx]
public Tensor this[int idx] => slice(idx);

public Tensor this[params string[] slices]
{ {
get get
{ {
return slice(idx);
var slice_spec = slices.Select(x => x == null ? null : new Slice(x)).ToArray();
var begin = new List<int>();
var end = new List<int>();
var strides = new List<int>();

var index = 0;
var (new_axis_mask, shrink_axis_mask) = (0, 0);
var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0;

foreach (var s in slice_spec)
{
if(s == null)
{
begin.Add(0);
end.Add(0);
strides.Add(0);
new_axis_mask |= (1 << index);
}
else
{
if (s.Start.HasValue)
{
begin.Add(s.Start.Value);
}
else
{
begin.Add(0);
begin_mask |= (1 << index);
}

if (s.Stop.HasValue)
{
end.Add(s.Stop.Value);
}
else
{
end.Add(0);
end_mask |= (1 << index);
}

strides.Add(s.Step);
}

index += 1;
}

return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
{
string name = scope;
if (begin != null)
{
var (packed_begin, packed_end, packed_strides) =
(array_ops.stack(begin.ToArray()),
array_ops.stack(end.ToArray()),
array_ops.stack(strides.ToArray()));

return gen_array_ops.strided_slice(
this,
packed_begin,
packed_end,
packed_strides,
begin_mask: begin_mask,
end_mask: end_mask,
shrink_axis_mask: shrink_axis_mask,
new_axis_mask: new_axis_mask,
ellipsis_mask: ellipsis_mask,
name: name);
}

throw new NotImplementedException("");
});
} }
} }




Loading…
Cancel
Save