| @@ -1139,5 +1139,18 @@ namespace Tensorflow | |||
| var _op = tf.OpDefLib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape }); | |||
| return _op.output; | |||
| } | |||
| public static int get_positive_axis(int axis, int ndims=-100, string axis_name="axis", string ndims_name= "ndims") | |||
| { | |||
| if(ndims != -100) | |||
| { | |||
| if (axis >= 0 && axis < ndims) return axis; | |||
| else if (-ndims <= axis && axis < 0) return axis + ndims; | |||
| else throw new ValueError($"{axis_name}={axis} out of bounds:expected {-ndims}<={axis_name}<{ndims}"); | |||
| } else if(axis < 0) throw new ValueError($"{axis_name}={axis} may only be negative if {ndims_name} is statically known."); | |||
| return axis; | |||
| } | |||
| } | |||
| } | |||
| @@ -163,5 +163,38 @@ namespace Tensorflow | |||
| { | |||
| return tensor.Tag as RaggedTensor; | |||
| } | |||
| public Tensor nrows(TF_DataType out_type, string name = null) | |||
| { | |||
| tf_with(ops.name_scope(name, "RaggedNRows"), scope => | |||
| { | |||
| return math_ops.cast(this._row_partition.nrows(), dtype: out_type); | |||
| }); | |||
| return null; | |||
| } | |||
| public RaggedTensor row_lengths(int axis=-1, string name=null) | |||
| { | |||
| if (axis == 0) return this._row_partition.nrows(); | |||
| if (axis == 1) return this._row_partition.row_lengths(); | |||
| var values = (RaggedTensor)this._values; | |||
| axis = array_ops.get_positive_axis( | |||
| axis, this.shape.rank, ndims_name: "rank(this)"); | |||
| if (axis == 0) return this.nrows(this._row_partition.GetDataType()); | |||
| else if (axis == 1) | |||
| { | |||
| var splits = this._row_partition.row_splits; | |||
| return splits[new Slice(start: 1)] - splits[new Slice(stop: -1)]; | |||
| } | |||
| else if (this._values is RaggedTensor) | |||
| { | |||
| return values.row_lengths(axis - 1); | |||
| } | |||
| else | |||
| { | |||
| var shape = array_ops.shape(values, out_type: this._row_partition.GetDataType()); | |||
| return array_ops.ones(shape[new Slice(stop:axis - 1)], this._row_partition.GetDataType()) * | |||
| shape[axis - 1]; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -14,10 +14,15 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using Serilog.Debugging; | |||
| using System; | |||
| using System.Collections.Concurrent; | |||
| using System.Collections.Generic; | |||
| //using System.ComponentModel.DataAnnotations; | |||
| using System.Text; | |||
| using System.Xml.Linq; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.NumPy; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -99,5 +104,55 @@ namespace Tensorflow | |||
| return new RowPartition(row_splits); | |||
| }); | |||
| } | |||
| public static RowPartition from_row_lengths(Tensor row_lengths, | |||
| bool validate=true, | |||
| TF_DataType dtype = TF_DataType.TF_INT32, | |||
| TF_DataType dtype_hint= TF_DataType.TF_INT32) | |||
| { | |||
| row_lengths = _convert_row_partition( | |||
| row_lengths, "row_lengths", dtype_hint: dtype_hint, dtype: dtype); | |||
| Tensor row_limits = math_ops.cumsum<Tensor>(row_lengths, tf.constant(-1)); | |||
| Tensor row_splits = array_ops.concat(new Tensor[] { tf.convert_to_tensor(np.array(new int[] { 0 }, TF_DataType.TF_INT64)), row_limits }, axis:0); | |||
| return new RowPartition(row_splits: row_splits, row_lengths: row_lengths); | |||
| } | |||
| public static Tensor _convert_row_partition(Tensor partition, string name, TF_DataType dtype, | |||
| TF_DataType dtype_hint= TF_DataType.TF_INT64) | |||
| { | |||
| if (partition is NDArray && partition.GetDataType() == np.int32) partition = ops.convert_to_tensor(partition, name: name); | |||
| if (partition.GetDataType() != np.int32 && partition.GetDataType() != np.int64) throw new ValueError($"{name} must have dtype int32 or int64"); | |||
| return partition; | |||
| } | |||
| public Tensor nrows() | |||
| { | |||
| /*Returns the number of rows created by this `RowPartition*/ | |||
| if (this._nrows != null) return this._nrows; | |||
| var nsplits = tensor_shape.dimension_at_index(this._row_splits.shape, 0); | |||
| if (nsplits == null) return array_ops.shape(this._row_splits, out_type: this.row_splits.dtype)[0] - 1; | |||
| else return constant_op.constant(nsplits.value - 1, dtype: this.row_splits.dtype); | |||
| } | |||
| public Tensor row_lengths() | |||
| { | |||
| if (this._row_splits != null) | |||
| { | |||
| int nrows_plus_one = tensor_shape.dimension_value(this._row_splits.shape[0]); | |||
| return tf.constant(nrows_plus_one - 1); | |||
| } | |||
| if (this._row_lengths != null) | |||
| { | |||
| var nrows = tensor_shape.dimension_value(this._row_lengths.shape[0]); | |||
| return tf.constant(nrows); | |||
| } | |||
| if(this._nrows != null) | |||
| { | |||
| return tensor_util.constant_value(this._nrows); | |||
| } | |||
| return tf.constant(-1); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,26 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using Tensorflow.NumPy; | |||
| using static Tensorflow.Binding; | |||
| namespace TensorFlowNET.UnitTest.ManagedAPI | |||
| { | |||
| public class RaggedTensorTest :EagerModeTestBase | |||
| { | |||
| [TestMethod] | |||
| public void Test_from_row_lengths() | |||
| { | |||
| var row_lengths = tf.convert_to_tensor(np.array(new int[] { 2, 0, 3, 1, 1 }, TF_DataType.TF_INT64)); | |||
| var rp = RowPartition.from_row_lengths(row_lengths, validate: false); | |||
| var rp_row_lengths = rp.row_lengths(); | |||
| var rp_nrows = rp.nrows(); | |||
| Assert.IsTrue(rp_nrows.ToArray<long>()[0] == rp.nrows().ToArray<long>()[0]); | |||
| } | |||
| } | |||
| } | |||