Browse Source

Fix slice in Tensor.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
13e4e3e20e
6 changed files with 45 additions and 17 deletions
  1. +1
    -2
      src/TensorFlowNET.Core/Operations/array_ops.cs
  2. +3
    -2
      src/TensorFlowNET.Core/Operations/string_ops.cs
  3. +21
    -1
      src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
  5. +15
    -10
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  6. +3
    -0
      src/TensorFlowNET.Core/Tensors/tf.constant.cs

+ 1
- 2
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -191,14 +191,13 @@ namespace Tensorflow

private static Tensor _constant_if_small<T>(T value, TensorShape shape, TF_DataType dtype, string name)
{
Tensor shape_t = null;
if (shape.size < 1000)
{
return constant_op.constant(value, shape: shape, dtype: dtype, name: name);
}
else
{
shape_t = constant_op._tensor_shape_tensor_conversion_function(shape);
var shape_t = constant_op._tensor_shape_tensor_conversion_function(shape);
var c = constant_op.constant(0, dtype: dtype);
return gen_array_ops.fill(shape_t, c, name: name);
}


+ 3
- 2
src/TensorFlowNET.Core/Operations/string_ops.cs View File

@@ -80,10 +80,11 @@ namespace Tensorflow
var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING);
if(input.rank == 0)
{
return string_split_v2(array_ops.stack(new[] { input }),
var parts = string_split_v2(array_ops.stack(new[] { input }),
sep: sep,
maxsplit: maxsplit,
name: name)[0];
name: name);
return parts;
}
var result = tf.Context.ExecuteOp("StringSplitV2", name,


+ 21
- 1
src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs View File

@@ -44,6 +44,18 @@ namespace Tensorflow
}
}

public Tensor this[int index]
{
get
{
return tf_with(ops.name_scope(null, "RaggedGetItem"), scope =>
{
string name = scope;
return _ragged_getitem(index);
});
}
}

public RaggedTensor this[params Slice[] slices]
{
get
@@ -61,6 +73,14 @@ namespace Tensorflow
}
}

Tensor _ragged_getitem(int row_key)
{
var starts = _row_splits[":-1"];
var limits = _row_splits["1:"];
var row = _values[starts[row_key], limits[row_key]];
return row;
}

RaggedTensor _ragged_getitem_inner_dimensions(RaggedTensor input, Slice[] slices)
{
return input;
@@ -134,7 +154,7 @@ namespace Tensorflow
=> new[] { _row_splits };

public override string ToString()
=> $"tf.RaggedTensor: shape={_values.TensorShape} [{string.Join(", ", _values.StringData().Take(10))}]";
=> $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]";

public static implicit operator Tensor(RaggedTensor indexedSlices)
=> indexedSlices._to_variant();


+ 2
- 2
src/TensorFlowNET.Core/Tensors/Tensor.Index.cs View File

@@ -120,11 +120,11 @@ namespace Tensorflow
});
}

public Tensor this[params Tensor[] slice]
public Tensor this[Tensor start, Tensor stop = null, Tensor step = null]
{
get
{
var args = tensor_util.ParseSlices(slice);
var args = tensor_util.ParseSlices(start, stop: stop, step: step);

return tf_with(ops.name_scope(null, "strided_slice", args), scope =>
{


+ 15
- 10
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -674,25 +674,30 @@ would not be rank 1.", tensor.op.get_attr("axis")));
};
}

public static ParsedSliceArgs ParseSlices(Tensor[] slices)
public static ParsedSliceArgs ParseSlices(Tensor start, Tensor stop = null, Tensor step = null)
{
var begin = new List<Tensor>();
var end = new List<Tensor>();
var strides = new List<Tensor>();

var index = 0;
// var index = 0;
var (new_axis_mask, shrink_axis_mask) = (0, 0);
var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0;

foreach (var s in slices)
{
begin.Add(s);
end.Add(s + 1);
shrink_axis_mask |= (1 << index);
strides.Add(tf.constant(1, dtype: s.dtype));
index += 1;
}
begin.Add(start);

if (stop == null)
end.Add(start + 1);
else
end.Add(stop);

// shrink_axis_mask |= (1 << index);

if (step == null)
strides.Add(tf.constant(1, dtype: start.dtype));
else
strides.Add(step);

return new ParsedSliceArgs
{


+ 3
- 0
src/TensorFlowNET.Core/Tensors/tf.constant.cs View File

@@ -40,6 +40,9 @@ namespace Tensorflow
public Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
=> array_ops.zeros(shape, dtype, name);

public Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
=> array_ops.zeros(shape, dtype, name);

public Tensor ones(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
=> array_ops.ones(shape, dtype, name);



Loading…
Cancel
Save