|
|
@@ -14,10 +14,15 @@ |
|
|
limitations under the License. |
|
|
limitations under the License. |
|
|
******************************************************************************/ |
|
|
******************************************************************************/ |
|
|
|
|
|
|
|
|
|
|
|
using Serilog.Debugging; |
|
|
using System; |
|
|
using System; |
|
|
|
|
|
using System.Collections.Concurrent; |
|
|
using System.Collections.Generic; |
|
|
using System.Collections.Generic; |
|
|
|
|
|
//using System.ComponentModel.DataAnnotations; |
|
|
using System.Text; |
|
|
using System.Text; |
|
|
|
|
|
using System.Xml.Linq; |
|
|
using Tensorflow.Framework; |
|
|
using Tensorflow.Framework; |
|
|
|
|
|
using Tensorflow.NumPy; |
|
|
using static Tensorflow.Binding; |
|
|
using static Tensorflow.Binding; |
|
|
|
|
|
|
|
|
namespace Tensorflow |
|
|
namespace Tensorflow |
|
|
@@ -99,5 +104,55 @@ namespace Tensorflow |
|
|
return new RowPartition(row_splits); |
|
|
return new RowPartition(row_splits); |
|
|
}); |
|
|
}); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public static RowPartition from_row_lengths(Tensor row_lengths, |
|
|
|
|
|
bool validate=true, |
|
|
|
|
|
TF_DataType dtype = TF_DataType.TF_INT32, |
|
|
|
|
|
TF_DataType dtype_hint= TF_DataType.TF_INT32) |
|
|
|
|
|
{ |
|
|
|
|
|
row_lengths = _convert_row_partition( |
|
|
|
|
|
row_lengths, "row_lengths", dtype_hint: dtype_hint, dtype: dtype); |
|
|
|
|
|
Tensor row_limits = math_ops.cumsum<Tensor>(row_lengths, tf.constant(-1)); |
|
|
|
|
|
Tensor row_splits = array_ops.concat(new Tensor[] { tf.convert_to_tensor(np.array(new int[] { 0 }, TF_DataType.TF_INT64)), row_limits }, axis:0); |
|
|
|
|
|
return new RowPartition(row_splits: row_splits, row_lengths: row_lengths); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public static Tensor _convert_row_partition(Tensor partition, string name, TF_DataType dtype, |
|
|
|
|
|
TF_DataType dtype_hint= TF_DataType.TF_INT64) |
|
|
|
|
|
{ |
|
|
|
|
|
if (partition is NDArray && partition.GetDataType() == np.int32) partition = ops.convert_to_tensor(partition, name: name); |
|
|
|
|
|
if (partition.GetDataType() != np.int32 && partition.GetDataType() != np.int64) throw new ValueError($"{name} must have dtype int32 or int64"); |
|
|
|
|
|
return partition; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public Tensor nrows() |
|
|
|
|
|
{ |
|
|
|
|
|
/*Returns the number of rows created by this `RowPartition*/ |
|
|
|
|
|
if (this._nrows != null) return this._nrows; |
|
|
|
|
|
var nsplits = tensor_shape.dimension_at_index(this._row_splits.shape, 0); |
|
|
|
|
|
if (nsplits == null) return array_ops.shape(this._row_splits, out_type: this.row_splits.dtype)[0] - 1; |
|
|
|
|
|
else return constant_op.constant(nsplits.value - 1, dtype: this.row_splits.dtype); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
public Tensor row_lengths() |
|
|
|
|
|
{ |
|
|
|
|
|
|
|
|
|
|
|
if (this._row_splits != null) |
|
|
|
|
|
{ |
|
|
|
|
|
int nrows_plus_one = tensor_shape.dimension_value(this._row_splits.shape[0]); |
|
|
|
|
|
return tf.constant(nrows_plus_one - 1); |
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
if (this._row_lengths != null) |
|
|
|
|
|
{ |
|
|
|
|
|
var nrows = tensor_shape.dimension_value(this._row_lengths.shape[0]); |
|
|
|
|
|
return tf.constant(nrows); |
|
|
|
|
|
} |
|
|
|
|
|
if(this._nrows != null) |
|
|
|
|
|
{ |
|
|
|
|
|
return tensor_util.constant_value(this._nrows); |
|
|
|
|
|
} |
|
|
|
|
|
return tf.constant(-1); |
|
|
|
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} |