diff --git a/src/TensorFlowNET.Core/Clustering/KMeans.cs b/src/TensorFlowNET.Core/Clustering/KMeans.cs
index ac945cb2..cb0da4c6 100644
--- a/src/TensorFlowNET.Core/Clustering/KMeans.cs
+++ b/src/TensorFlowNET.Core/Clustering/KMeans.cs
@@ -53,7 +53,17 @@ namespace Tensorflow.Clustering
var initial_clusters = _initial_clusters;
var num_clusters = ops.convert_to_tensor(_num_clusters);
var inputs = _inputs;
- _create_variables(num_clusters);
+ var vars = _create_variables(num_clusters);
+ var cluster_centers_var = vars[0];
+ var cluster_centers_initialized = vars[1];
+ var total_counts = vars[2];
+ var cluster_centers_updated = vars[3];
+ var update_in_steps = vars[4];
+
+ var init_op = new _InitializeClustersOpFactory(_inputs, num_clusters, initial_clusters, _distance_metric,
+ _random_seed, _kmeans_plus_plus_num_retries,
+ _kmc2_chain_length, cluster_centers_var, cluster_centers_updated,
+ cluster_centers_initialized).op();
throw new NotImplementedException("KMeans training_graph");
}
diff --git a/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs b/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs
new file mode 100644
index 00000000..0daafdfe
--- /dev/null
+++ b/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs
@@ -0,0 +1,62 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+
+namespace Tensorflow.Clustering
+{
+ ///
+ /// Internal class to create the op to initialize the clusters.
+ ///
+ public class _InitializeClustersOpFactory
+ {
+ Tensor[] _inputs;
+ Tensor _num_clusters;
+ IInitializer _initial_clusters;
+ string _distance_metric;
+ int _random_seed;
+ int _kmeans_plus_plus_num_retries;
+ int _kmc2_chain_length;
+ RefVariable _cluster_centers;
+ RefVariable _cluster_centers_updated;
+ RefVariable _cluster_centers_initialized;
+ Tensor _num_selected;
+ Tensor _num_remaining;
+ Tensor _num_data;
+
+ public _InitializeClustersOpFactory(Tensor[] inputs,
+ Tensor num_clusters,
+ IInitializer initial_clusters,
+ string distance_metric,
+ int random_seed,
+ int kmeans_plus_plus_num_retries,
+ int kmc2_chain_length,
+ RefVariable cluster_centers,
+ RefVariable cluster_centers_updated,
+ RefVariable cluster_centers_initialized)
+ {
+ _inputs = inputs;
+ _num_clusters = num_clusters;
+ _initial_clusters = initial_clusters;
+ _distance_metric = distance_metric;
+ _random_seed = random_seed;
+ _kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries;
+ _kmc2_chain_length = kmc2_chain_length;
+ _cluster_centers = cluster_centers;
+ _cluster_centers_updated = cluster_centers_updated;
+ _cluster_centers_initialized = cluster_centers_initialized;
+
+ _num_selected = array_ops.shape(_cluster_centers)[0];
+ _num_remaining = _num_clusters - _num_selected;
+
+ _num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i)[0]).ToArray());
+ }
+
+ public Tensor[] op()
+ {
+ return control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0),
+ () => new Operation[] { check_ops.assert_equal(_cluster_centers_initialized, true) },
+ () => new Operation[0]);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/check_ops.cs b/src/TensorFlowNET.Core/Operations/check_ops.cs
new file mode 100644
index 00000000..bad058e9
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/check_ops.cs
@@ -0,0 +1,28 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class check_ops : Python
+ {
+ ///
+ /// Assert the condition `x == y` holds element-wise.
+ ///
+ ///
+ ///
+ ///
+ public static Operation assert_equal(object t1, object t2, object[] data = null, string name = null)
+ {
+ return with(ops.name_scope(name, "assert_equal", new { t1, t2, data }), delegate
+ {
+ var x = ops.convert_to_tensor(t1, name: "x");
+ var y = ops.convert_to_tensor(t2, name: "y");
+ var condition = math_ops.reduce_all(gen_math_ops.equal(x, y));
+ var x_static = tensor_util.constant_value(x);
+ var y_static = tensor_util.constant_value(y);
+ return control_flow_ops.Asset(condition, data);
+ });
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
index 03773c9f..95317afb 100644
--- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs
@@ -9,6 +9,29 @@ namespace Tensorflow
{
public class control_flow_ops : Python
{
+ public static Operation Asset(Tensor condition, object[] data, int? summarize = null, string name = null)
+ {
+ return with(ops.name_scope(name, "Assert", new { condition, data }), scope =>
+ {
+ name = scope;
+ var xs = ops.convert_n_to_tensor(data);
+ condition = ops.convert_to_tensor(condition, name: "Condition");
+ Func true_assert = () => new Operation[]
+ {
+ gen_logging_ops._assert(condition, data, summarize, name: "Assert")
+ };
+
+ Func false_assert = () => new Operation[]
+ {
+ gen_control_flow_ops.no_op()
+ };
+
+ var guarded_assert = cond(condition, false_assert, true_assert, name: "AssertGuard");
+
+ return guarded_assert[0].op;
+ });
+ }
+
public static Operation group(T[] inputs, string name = null) where T : ITensorOrOperation
{
return with(ops.name_scope(name, "group_deps", inputs), scope =>
diff --git a/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs b/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs
new file mode 100644
index 00000000..6f11dd0e
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/gen_logging_ops.cs
@@ -0,0 +1,21 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ public class gen_logging_ops
+ {
+ public static OpDefLibrary _op_def_lib = new OpDefLibrary();
+
+ public static Operation _assert(Tensor condition, object[] data, int? summarize = 3, string name = null)
+ {
+ if (!summarize.HasValue)
+ summarize = 3;
+
+ var _op = _op_def_lib._apply_op_helper("Assert", name, args: new { condition, data, summarize });
+
+ return _op;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index a3f017bf..df1f2d88 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -10,6 +10,13 @@ namespace Tensorflow
{
public static OpDefLibrary _op_def_lib = new OpDefLibrary();
+ public static Tensor _all(Tensor input, Tensor axis, bool keep_dims = false, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("All", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims });
+
+ return _op.outputs[0];
+ }
+
///
/// Returns the index with the largest value across dimensions of a tensor.
///
@@ -250,7 +257,7 @@ namespace Tensorflow
///
///
///
- public static Tensor equal(Tensor x, Tensor y, string name = null)
+ public static Tensor equal(Tx x, Ty y, string name = null)
{
var _op = _op_def_lib._apply_op_helper("Equal", name, args: new { x, y });
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 8077a87a..6d07376a 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -37,6 +37,27 @@ namespace Tensorflow
});
}
+ ///
+ /// Adds all input tensors element-wise.
+ ///
+ ///
+ ///
+ ///
+ public static Tensor add_n(Tensor[] inputs, string name = null)
+ {
+ inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs);
+
+ if(inputs.Length == 1)
+ {
+ var values = inputs[0];
+ if (name != null)
+ return array_ops.identity(values, name: name);
+ return values;
+ }
+ throw new NotImplementedException("math_ops add_n n > 1");
+ // return gen_math_ops.add_n(inputs, name: name);
+ }
+
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
{
var base_type = dtype.as_base_dtype();
@@ -161,7 +182,24 @@ namespace Tensorflow
///
public static Tensor reciprocal(Tensor x, string name = null)
=> gen_math_ops.reciprocal(x, name: name);
-
+
+ ///
+ /// Computes the "logical and" of elements across dimensions of a tensor.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public static Tensor reduce_all(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
+ {
+ var all = gen_math_ops._all(input_tensor,
+ _ReductionDims(input_tensor, axis),
+ keepdims,
+ name: name);
+
+ return _may_reduce_to_scalar(keepdims, axis, all);
+ }
///
/// Computes log(sum(exp(elements across dimensions of a tensor))).
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index 602e0137..b2de43fb 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -346,20 +346,17 @@ namespace Tensorflow
session.run(operation, feed_dict);
}
+ public static Tensor[] convert_n_to_tensor(object[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
+ => internal_convert_n_to_tensor(values, dtype: dtype, name: name, as_ref: false);
+
public static Tensor[] convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
- {
- return internal_convert_n_to_tensor_or_indexed_slices(values, dtype: dtype, name: name);
- }
+ => internal_convert_n_to_tensor_or_indexed_slices(values, dtype: dtype, name: name);
public static Tensor convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
- {
- return internal_convert_to_tensor_or_indexed_slices(value: value, dtype: dtype, name: name, as_ref: false);
- }
+ => internal_convert_to_tensor_or_indexed_slices(value: value, dtype: dtype, name: name, as_ref: false);
public static Tensor internal_convert_to_tensor_or_indexed_slices(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
- {
- return value;
- }
+ => value;
public static Tensor[] internal_convert_n_to_tensor_or_indexed_slices(Tensor[] values, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false)
{
diff --git a/test/TensorFlowNET.Examples/KMeansClustering.cs b/test/TensorFlowNET.Examples/KMeansClustering.cs
index 1305fef5..fcec5efa 100644
--- a/test/TensorFlowNET.Examples/KMeansClustering.cs
+++ b/test/TensorFlowNET.Examples/KMeansClustering.cs
@@ -39,7 +39,7 @@ namespace TensorFlowNET.Examples
// Build KMeans graph
var training_graph = kmeans.training_graph();
-
+
return false;
}