diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
index 526ccb8d..7ac5063c 100644
--- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
+++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
@@ -27,7 +27,8 @@ Docs: https://tensorflownet.readthedocs.io
6. Add Local Response Normalization.
7. Add tf.image related APIs.
8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor.
-9. MultiThread is safe.
+9. MultiThread is safe.
+10. Support n-dim indexing for tensor.
7.3
0.11.2.0
LICENSE
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
index 244c0ccf..6632550f 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
@@ -17,6 +17,7 @@
using NumSharp;
using System;
using System.Collections.Generic;
+using System.Linq;
using System.Text;
using static Tensorflow.Binding;
@@ -24,11 +25,84 @@ namespace Tensorflow
{
public partial class Tensor
{
- public Tensor this[int idx]
+ public Tensor this[int idx] => slice(idx);
+
+ public Tensor this[params string[] slices]
{
get
{
- return slice(idx);
+ var slice_spec = slices.Select(x => x == null ? null : new Slice(x)).ToArray();
+ var begin = new List();
+ var end = new List();
+ var strides = new List();
+
+ var index = 0;
+ var (new_axis_mask, shrink_axis_mask) = (0, 0);
+ var (begin_mask, end_mask) = (0, 0);
+ var ellipsis_mask = 0;
+
+ foreach (var s in slice_spec)
+ {
+ if(s == null)
+ {
+ begin.Add(0);
+ end.Add(0);
+ strides.Add(0);
+ new_axis_mask |= (1 << index);
+ }
+ else
+ {
+ if (s.Start.HasValue)
+ {
+ begin.Add(s.Start.Value);
+ }
+ else
+ {
+ begin.Add(0);
+ begin_mask |= (1 << index);
+ }
+
+ if (s.Stop.HasValue)
+ {
+ end.Add(s.Stop.Value);
+ }
+ else
+ {
+ end.Add(0);
+ end_mask |= (1 << index);
+ }
+
+ strides.Add(s.Step);
+ }
+
+ index += 1;
+ }
+
+ return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope =>
+ {
+ string name = scope;
+ if (begin != null)
+ {
+ var (packed_begin, packed_end, packed_strides) =
+ (array_ops.stack(begin.ToArray()),
+ array_ops.stack(end.ToArray()),
+ array_ops.stack(strides.ToArray()));
+
+ return gen_array_ops.strided_slice(
+ this,
+ packed_begin,
+ packed_end,
+ packed_strides,
+ begin_mask: begin_mask,
+ end_mask: end_mask,
+ shrink_axis_mask: shrink_axis_mask,
+ new_axis_mask: new_axis_mask,
+ ellipsis_mask: ellipsis_mask,
+ name: name);
+ }
+
+ throw new NotImplementedException("");
+ });
}
}