From 13e4e3e20e185cd44a6376015267655ca32c8395 Mon Sep 17 00:00:00 2001 From: Oceania2018 Date: Sat, 8 May 2021 19:07:22 -0500 Subject: [PATCH] Fix slice in Tensor. --- .../Operations/array_ops.cs | 3 +-- .../Operations/string_ops.cs | 5 ++-- .../Tensors/Ragged/RaggedTensor.cs | 22 +++++++++++++++- .../Tensors/Tensor.Index.cs | 4 +-- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 25 +++++++++++-------- src/TensorFlowNET.Core/Tensors/tf.constant.cs | 3 +++ 6 files changed, 45 insertions(+), 17 deletions(-) diff --git a/src/TensorFlowNET.Core/Operations/array_ops.cs b/src/TensorFlowNET.Core/Operations/array_ops.cs index d683c0be..ffc1c9a7 100644 --- a/src/TensorFlowNET.Core/Operations/array_ops.cs +++ b/src/TensorFlowNET.Core/Operations/array_ops.cs @@ -191,14 +191,13 @@ namespace Tensorflow private static Tensor _constant_if_small(T value, TensorShape shape, TF_DataType dtype, string name) { - Tensor shape_t = null; if (shape.size < 1000) { return constant_op.constant(value, shape: shape, dtype: dtype, name: name); } else { - shape_t = constant_op._tensor_shape_tensor_conversion_function(shape); + var shape_t = constant_op._tensor_shape_tensor_conversion_function(shape); var c = constant_op.constant(0, dtype: dtype); return gen_array_ops.fill(shape_t, c, name: name); } diff --git a/src/TensorFlowNET.Core/Operations/string_ops.cs b/src/TensorFlowNET.Core/Operations/string_ops.cs index 0a4169b6..2fe4a7f0 100644 --- a/src/TensorFlowNET.Core/Operations/string_ops.cs +++ b/src/TensorFlowNET.Core/Operations/string_ops.cs @@ -80,10 +80,11 @@ namespace Tensorflow var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING); if(input.rank == 0) { - return string_split_v2(array_ops.stack(new[] { input }), + var parts = string_split_v2(array_ops.stack(new[] { input }), sep: sep, maxsplit: maxsplit, - name: name)[0]; + name: name); + return parts; } var result = tf.Context.ExecuteOp("StringSplitV2", name, diff --git a/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs index aadba0c5..9c7d96f8 100644 --- a/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs +++ b/src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs @@ -44,6 +44,18 @@ namespace Tensorflow } } + public Tensor this[int index] + { + get + { + return tf_with(ops.name_scope(null, "RaggedGetItem"), scope => + { + string name = scope; + return _ragged_getitem(index); + }); + } + } + public RaggedTensor this[params Slice[] slices] { get @@ -61,6 +73,14 @@ namespace Tensorflow } } + Tensor _ragged_getitem(int row_key) + { + var starts = _row_splits[":-1"]; + var limits = _row_splits["1:"]; + var row = _values[starts[row_key], limits[row_key]]; + return row; + } + RaggedTensor _ragged_getitem_inner_dimensions(RaggedTensor input, Slice[] slices) { return input; @@ -134,7 +154,7 @@ namespace Tensorflow => new[] { _row_splits }; public override string ToString() - => $"tf.RaggedTensor: shape={_values.TensorShape} [{string.Join(", ", _values.StringData().Take(10))}]"; + => $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]"; public static implicit operator Tensor(RaggedTensor indexedSlices) => indexedSlices._to_variant(); diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs index a12f1fb5..4db3266c 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Index.cs @@ -120,11 +120,11 @@ namespace Tensorflow }); } - public Tensor this[params Tensor[] slice] + public Tensor this[Tensor start, Tensor stop = null, Tensor step = null] { get { - var args = tensor_util.ParseSlices(slice); + var args = tensor_util.ParseSlices(start, stop: stop, step: step); return tf_with(ops.name_scope(null, "strided_slice", args), scope => { diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index ccc5c31c..110f38df 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -674,25 +674,30 @@ would not be rank 1.", tensor.op.get_attr("axis"))); }; } - public static ParsedSliceArgs ParseSlices(Tensor[] slices) + public static ParsedSliceArgs ParseSlices(Tensor start, Tensor stop = null, Tensor step = null) { var begin = new List(); var end = new List(); var strides = new List(); - var index = 0; + // var index = 0; var (new_axis_mask, shrink_axis_mask) = (0, 0); var (begin_mask, end_mask) = (0, 0); var ellipsis_mask = 0; - foreach (var s in slices) - { - begin.Add(s); - end.Add(s + 1); - shrink_axis_mask |= (1 << index); - strides.Add(tf.constant(1, dtype: s.dtype)); - index += 1; - } + begin.Add(start); + + if (stop == null) + end.Add(start + 1); + else + end.Add(stop); + + // shrink_axis_mask |= (1 << index); + + if (step == null) + strides.Add(tf.constant(1, dtype: start.dtype)); + else + strides.Add(step); return new ParsedSliceArgs { diff --git a/src/TensorFlowNET.Core/Tensors/tf.constant.cs b/src/TensorFlowNET.Core/Tensors/tf.constant.cs index baa422a4..291e8d0c 100644 --- a/src/TensorFlowNET.Core/Tensors/tf.constant.cs +++ b/src/TensorFlowNET.Core/Tensors/tf.constant.cs @@ -40,6 +40,9 @@ namespace Tensorflow public Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) => array_ops.zeros(shape, dtype, name); + public Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) + => array_ops.zeros(shape, dtype, name); + public Tensor ones(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) => array_ops.ones(shape, dtype, name);