diff --git a/src/TensorFlowNET.Core/APIs/tf.data_flow.cs b/src/TensorFlowNET.Core/APIs/tf.data_flow.cs index 593596ff..3ea6a70d 100644 --- a/src/TensorFlowNET.Core/APIs/tf.data_flow.cs +++ b/src/TensorFlowNET.Core/APIs/tf.data_flow.cs @@ -27,7 +27,19 @@ namespace Tensorflow /// /// /// - public static Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = null) + public Tensor dynamic_stitch(Tensor[] indices, Tensor[] data, string name = null) => gen_data_flow_ops.dynamic_stitch(indices, data, name: name); + + /// + /// Partitions `data` into `num_partitions` tensors using indices from `partitions`. + /// + /// + /// + /// The number of partitions to output. + /// + /// + public Tensor[] dynamic_partition(Tensor data, Tensor partitions, int num_partitions, + string name = null) + => gen_data_flow_ops.dynamic_partition(data, partitions, num_partitions, name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs index 65b86f04..37ae486e 100644 --- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs +++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs @@ -27,6 +27,19 @@ namespace Tensorflow return _op.output; } + public static Tensor[] dynamic_partition(Tensor data, Tensor partitions, int num_partitions, + string name = null) + { + var _op = _op_def_lib._apply_op_helper("DynamicPartition", name, new + { + data, + partitions, + num_partitions + }); + + return _op.outputs; + } + public static (Tensor, Tensor) tensor_array_v3(T size, TF_DataType dtype = TF_DataType.DtInvalid, TensorShape element_shape = null, bool dynamic_size = false, bool clear_after_read = true, bool identical_element_shapes = false, string tensor_array_name = "", string name = null)