| @@ -191,14 +191,13 @@ namespace Tensorflow | |||||
| private static Tensor _constant_if_small<T>(T value, TensorShape shape, TF_DataType dtype, string name) | private static Tensor _constant_if_small<T>(T value, TensorShape shape, TF_DataType dtype, string name) | ||||
| { | { | ||||
| Tensor shape_t = null; | |||||
| if (shape.size < 1000) | if (shape.size < 1000) | ||||
| { | { | ||||
| return constant_op.constant(value, shape: shape, dtype: dtype, name: name); | return constant_op.constant(value, shape: shape, dtype: dtype, name: name); | ||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| shape_t = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
| var shape_t = constant_op._tensor_shape_tensor_conversion_function(shape); | |||||
| var c = constant_op.constant(0, dtype: dtype); | var c = constant_op.constant(0, dtype: dtype); | ||||
| return gen_array_ops.fill(shape_t, c, name: name); | return gen_array_ops.fill(shape_t, c, name: name); | ||||
| } | } | ||||
| @@ -80,10 +80,11 @@ namespace Tensorflow | |||||
| var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING); | var sep_tensor = ops.convert_to_tensor(sep, dtype: TF_DataType.TF_STRING); | ||||
| if(input.rank == 0) | if(input.rank == 0) | ||||
| { | { | ||||
| return string_split_v2(array_ops.stack(new[] { input }), | |||||
| var parts = string_split_v2(array_ops.stack(new[] { input }), | |||||
| sep: sep, | sep: sep, | ||||
| maxsplit: maxsplit, | maxsplit: maxsplit, | ||||
| name: name)[0]; | |||||
| name: name); | |||||
| return parts; | |||||
| } | } | ||||
| var result = tf.Context.ExecuteOp("StringSplitV2", name, | var result = tf.Context.ExecuteOp("StringSplitV2", name, | ||||
| @@ -44,6 +44,18 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| public Tensor this[int index] | |||||
| { | |||||
| get | |||||
| { | |||||
| return tf_with(ops.name_scope(null, "RaggedGetItem"), scope => | |||||
| { | |||||
| string name = scope; | |||||
| return _ragged_getitem(index); | |||||
| }); | |||||
| } | |||||
| } | |||||
| public RaggedTensor this[params Slice[] slices] | public RaggedTensor this[params Slice[] slices] | ||||
| { | { | ||||
| get | get | ||||
| @@ -61,6 +73,14 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| Tensor _ragged_getitem(int row_key) | |||||
| { | |||||
| var starts = _row_splits[":-1"]; | |||||
| var limits = _row_splits["1:"]; | |||||
| var row = _values[starts[row_key], limits[row_key]]; | |||||
| return row; | |||||
| } | |||||
| RaggedTensor _ragged_getitem_inner_dimensions(RaggedTensor input, Slice[] slices) | RaggedTensor _ragged_getitem_inner_dimensions(RaggedTensor input, Slice[] slices) | ||||
| { | { | ||||
| return input; | return input; | ||||
| @@ -134,7 +154,7 @@ namespace Tensorflow | |||||
| => new[] { _row_splits }; | => new[] { _row_splits }; | ||||
| public override string ToString() | public override string ToString() | ||||
| => $"tf.RaggedTensor: shape={_values.TensorShape} [{string.Join(", ", _values.StringData().Take(10))}]"; | |||||
| => $"tf.RaggedTensor: shape={shape} [{string.Join(", ", _values.StringData().Take(10))}]"; | |||||
| public static implicit operator Tensor(RaggedTensor indexedSlices) | public static implicit operator Tensor(RaggedTensor indexedSlices) | ||||
| => indexedSlices._to_variant(); | => indexedSlices._to_variant(); | ||||
| @@ -120,11 +120,11 @@ namespace Tensorflow | |||||
| }); | }); | ||||
| } | } | ||||
| public Tensor this[params Tensor[] slice] | |||||
| public Tensor this[Tensor start, Tensor stop = null, Tensor step = null] | |||||
| { | { | ||||
| get | get | ||||
| { | { | ||||
| var args = tensor_util.ParseSlices(slice); | |||||
| var args = tensor_util.ParseSlices(start, stop: stop, step: step); | |||||
| return tf_with(ops.name_scope(null, "strided_slice", args), scope => | return tf_with(ops.name_scope(null, "strided_slice", args), scope => | ||||
| { | { | ||||
| @@ -674,25 +674,30 @@ would not be rank 1.", tensor.op.get_attr("axis"))); | |||||
| }; | }; | ||||
| } | } | ||||
| public static ParsedSliceArgs ParseSlices(Tensor[] slices) | |||||
| public static ParsedSliceArgs ParseSlices(Tensor start, Tensor stop = null, Tensor step = null) | |||||
| { | { | ||||
| var begin = new List<Tensor>(); | var begin = new List<Tensor>(); | ||||
| var end = new List<Tensor>(); | var end = new List<Tensor>(); | ||||
| var strides = new List<Tensor>(); | var strides = new List<Tensor>(); | ||||
| var index = 0; | |||||
| // var index = 0; | |||||
| var (new_axis_mask, shrink_axis_mask) = (0, 0); | var (new_axis_mask, shrink_axis_mask) = (0, 0); | ||||
| var (begin_mask, end_mask) = (0, 0); | var (begin_mask, end_mask) = (0, 0); | ||||
| var ellipsis_mask = 0; | var ellipsis_mask = 0; | ||||
| foreach (var s in slices) | |||||
| { | |||||
| begin.Add(s); | |||||
| end.Add(s + 1); | |||||
| shrink_axis_mask |= (1 << index); | |||||
| strides.Add(tf.constant(1, dtype: s.dtype)); | |||||
| index += 1; | |||||
| } | |||||
| begin.Add(start); | |||||
| if (stop == null) | |||||
| end.Add(start + 1); | |||||
| else | |||||
| end.Add(stop); | |||||
| // shrink_axis_mask |= (1 << index); | |||||
| if (step == null) | |||||
| strides.Add(tf.constant(1, dtype: start.dtype)); | |||||
| else | |||||
| strides.Add(step); | |||||
| return new ParsedSliceArgs | return new ParsedSliceArgs | ||||
| { | { | ||||
| @@ -40,6 +40,9 @@ namespace Tensorflow | |||||
| public Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | public Tensor zeros(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | ||||
| => array_ops.zeros(shape, dtype, name); | => array_ops.zeros(shape, dtype, name); | ||||
| public Tensor zeros(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||||
| => array_ops.zeros(shape, dtype, name); | |||||
| public Tensor ones(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | public Tensor ones(TensorShape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | ||||
| => array_ops.ones(shape, dtype, name); | => array_ops.ones(shape, dtype, name); | ||||