Browse Source

Fix slice in Tensor.

tags/v0.60-tf.numpy
Oceania2018 4 years ago
parent
commit
13e4e3e20e
6 changed files with 45 additions and 17 deletions
  1. +1
    -2
      src/TensorFlowNET.Core/Operations/array_ops.cs
  2. +3
    -2
      src/TensorFlowNET.Core/Operations/string_ops.cs
  3. +21
    -1
      src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.Index.cs
  5. +15
    -10
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  6. +3
    -0
      src/TensorFlowNET.Core/Tensors/tf.constant.cs

+ 1
- 2
src/TensorFlowNET.Core/Operations/array_ops.cs View File

@@ -191,14 +191,13 @@ namespace Tensorflow


private static Tensor _constant_if_small<T>(T value, TensorShape shape, TF_DataType dtype, string name) private static Tensor _constant_if_small<T>(T value, TensorShape shape, TF_DataType dtype, string name)
{ {
Tensor shape_t = null;
if (shape.size < 1000) if (shape.size < 1000)
{ {
return constant_op.constant(value, shape: shape, dtype: dtype, name: name); return constant_op.constant(value, shape: shape, dtype: dtype, name: name);
} }
else else
{ {
shape_t = constant_op._tensor_shape_tensor_conversion_function(shape);
var shape_t = constant_op._tensor_shape_tensor_conversion_function(shape);
var c = constant_op.constant(0, dtype: dtype); var c = constant_op.constant(0, dtype: dtype);
return gen_array_ops.fill(shape_t, c, name: name); return gen_array_ops.fill(shape_t, c, name: name);
} }


+ 3
- 2
src/TensorFlowNET.Core/Operations/string_ops.cs View File

@@ -80,10 +80,11 @@ namespace Tensorflow
var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING); var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING);
if(input.rank == 0) if(input.rank == 0)
{ {
return string_split_v2(array_ops.stack(new[] { input }),
var parts = string_split_v2(array_ops.stack(new[] { input }),
sep: sep, sep: sep,
maxsplit: maxsplit, maxsplit: maxsplit,
name: name)[0];
name: name);
return parts;
} }
var result = tf.Context.ExecuteOp("StringSplitV2", name, var result = tf.Context.ExecuteOp("StringSplitV2", name,


+ 21
- 1
src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs View File

@@ -44,6 +44,18 @@ namespace Tensorflow
} }
} }


public Tensor this[int index]
{
get
{
return tf_with(ops.name_scope(null, "RaggedGetItem"), scope =>
{
string name = scope;
return _ragged_getitem(index);
});
}
}

public RaggedTensor this[params Slice[] slices] public RaggedTensor this[params Slice[] slices]
{ {
get get
@@ -61,6 +73,14 @@ namespace Tensorflow
} }
} }


Tensor _ragged_getitem(int row_key)
{
var starts = _row_splits[":-1"];
var limits = _row_splits["1:"];
var row = _values[starts[row_key], limits[row_key]];
return row;
}

RaggedTensor _ragged_getitem_inner_dimensions(RaggedTensor input, Slice[] slices) RaggedTensor _ragged_getitem_inner_dimensions(RaggedTensor input, Slice[] slices)
{ {
return input; return input;
@@ -134,7 +154,7 @@ namespace Tensorflow
=> new[] { _row_splits }; => new[] { _row_splits };


public override string ToString() public override string ToString()
=> $"tf.RaggedTensor: shape={_values.TensorShape} [{string.Join(", ", _values.StringData().Take(10))}]";
=> $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]";


public static implicit operator Tensor(RaggedTensor indexedSlices) public static implicit operator Tensor(RaggedTensor indexedSlices)
=> indexedSlices._to_variant(); => indexedSlices._to_variant();


+ 2
- 2
src/TensorFlowNET.Core/Tensors/Tensor.Index.cs View File

@@ -120,11 +120,11 @@ namespace Tensorflow
}); });
} }


public Tensor this[params Tensor[] slice]
public Tensor this[Tensor start, Tensor stop = null, Tensor step = null]
{ {
get get
{ {
var args = tensor_util.ParseSlices(slice);
var args = tensor_util.ParseSlices(start, stop: stop, step: step);


return tf_with(ops.name_scope(null, "strided_slice", args), scope => return tf_with(ops.name_scope(null, "strided_slice", args), scope =>
{ {


+ 15
- 10
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -674,25 +674,30 @@ would not be rank 1.", tensor.op.get_attr("axis")));
}; };
} }


public static ParsedSliceArgs ParseSlices(Tensor[] slices)
public static ParsedSliceArgs ParseSlices(Tensor start, Tensor stop = null, Tensor step = null)
{ {
var begin = new List<Tensor>(); var begin = new List<Tensor>();
var end = new List<Tensor>(); var end = new List<Tensor>();
var strides = new List<Tensor>(); var strides = new List<Tensor>();


var index = 0;
// var index = 0;
var (new_axis_mask, shrink_axis_mask) = (0, 0); var (new_axis_mask, shrink_axis_mask) = (0, 0);
var (begin_mask, end_mask) = (0, 0); var (begin_mask, end_mask) = (0, 0);
var ellipsis_mask = 0; var ellipsis_mask = 0;


foreach (var s in slices)
{
begin.Add(s);
end.Add(s + 1);
shrink_axis_mask |= (1 << index);
strides.Add(tf.constant(1, dtype: s.dtype));
index += 1;
}
begin.Add(start);

if (stop == null)
end.Add(start + 1);
else
end.Add(stop);

// shrink_axis_mask |= (1 << index);

if (step == null)
strides.Add(tf.constant(1, dtype: start.dtype));
else
strides.Add(step);


return new ParsedSliceArgs return new ParsedSliceArgs
{ {


+ 3
- 0
src/TensorFlowNET.Core/Tensors/tf.constant.cs View File

@@ -40,6 +40,9 @@ namespace Tensorflow
public Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) public Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
=> array_ops.zeros(shape, dtype, name); => array_ops.zeros(shape, dtype, name);


public Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
=> array_ops.zeros(shape, dtype, name);

public Tensor ones(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) public Tensor ones(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
=> array_ops.ones(shape, dtype, name); => array_ops.ones(shape, dtype, name);




Loading…
Cancel
Save