diff --git a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj index 526ccb8d..7ac5063c 100644 --- a/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj +++ b/src/TensorFlowNET.Core/TensorFlowNET.Core.csproj @@ -27,7 +27,8 @@ Docs: https://tensorflownet.readthedocs.io 6. Add Local Response Normalization. 7. Add tf.image related APIs. 8. Add tf.random_normal, tf.constant, tf.pad, tf.shape, tf.image.resize_nearest_neighbor. -9. MultiThread is safe. +9. MultiThread is safe. +10. Support n-dim indexing for tensor. 7.3 0.11.2.0 LICENSE diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs index 244c0ccf..6632550f 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs @@ -17,6 +17,7 @@ using NumSharp; using System; using System.Collections.Generic; +using System.Linq; using System.Text; using static Tensorflow.Binding; @@ -24,11 +25,84 @@ namespace Tensorflow { public partial class Tensor { - public Tensor this[int idx] + public Tensor this[int idx] => slice(idx); + + public Tensor this[params string[] slices] { get { - return slice(idx); + var slice_spec = slices.Select(x => x == null ? null : new Slice(x)).ToArray(); + var begin = new List(); + var end = new List(); + var strides = new List(); + + var index = 0; + var (new_axis_mask, shrink_axis_mask) = (0, 0); + var (begin_mask, end_mask) = (0, 0); + var ellipsis_mask = 0; + + foreach (var s in slice_spec) + { + if(s == null) + { + begin.Add(0); + end.Add(0); + strides.Add(0); + new_axis_mask |= (1 << index); + } + else + { + if (s.Start.HasValue) + { + begin.Add(s.Start.Value); + } + else + { + begin.Add(0); + begin_mask |= (1 << index); + } + + if (s.Stop.HasValue) + { + end.Add(s.Stop.Value); + } + else + { + end.Add(0); + end_mask |= (1 << index); + } + + strides.Add(s.Step); + } + + index += 1; + } + + return tf_with(ops.name_scope(null, "strided_slice", new { begin, end, strides }), scope => + { + string name = scope; + if (begin != null) + { + var (packed_begin, packed_end, packed_strides) = + (array_ops.stack(begin.ToArray()), + array_ops.stack(end.ToArray()), + array_ops.stack(strides.ToArray())); + + return gen_array_ops.strided_slice( + this, + packed_begin, + packed_end, + packed_strides, + begin_mask: begin_mask, + end_mask: end_mask, + shrink_axis_mask: shrink_axis_mask, + new_axis_mask: new_axis_mask, + ellipsis_mask: ellipsis_mask, + name: name); + } + + throw new NotImplementedException(""); + }); } }