From 128e56d8bb00f60d7a9e856023a04011df46f313 Mon Sep 17 00:00:00 2001 From: haiping008 Date: Fri, 29 Mar 2019 13:20:03 -0500 Subject: [PATCH] add check_ops.assert_positive, check_ops.assert_less --- src/TensorFlowNET.Core/Clustering/KMeans.cs | 4 +- .../_InitializeClustersOpFactory.cs | 26 +++++++-- .../Operations/check_ops.cs | 53 ++++++++++++++++++- .../Operations/control_flow_ops.py.cs | 2 +- 4 files changed, 77 insertions(+), 8 deletions(-) diff --git a/src/TensorFlowNET.Core/Clustering/KMeans.cs b/src/TensorFlowNET.Core/Clustering/KMeans.cs index cb0da4c6..783a8b71 100644 --- a/src/TensorFlowNET.Core/Clustering/KMeans.cs +++ b/src/TensorFlowNET.Core/Clustering/KMeans.cs @@ -19,7 +19,7 @@ namespace Tensorflow.Clustering Tensor[] _inputs; int _num_clusters; - IInitializer _initial_clusters; + string _initial_clusters; string _distance_metric; bool _use_mini_batch; int _mini_batch_steps_per_iteration; @@ -29,7 +29,7 @@ namespace Tensorflow.Clustering public KMeans(Tensor inputs, int num_clusters, - IInitializer initial_clusters = null, + string initial_clusters = RANDOM_INIT, string distance_metric = SQUARED_EUCLIDEAN_DISTANCE, bool use_mini_batch = false, int mini_batch_steps_per_iteration = 1, diff --git a/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs b/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs index 0daafdfe..14393708 100644 --- a/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs +++ b/src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs @@ -8,11 +8,11 @@ namespace Tensorflow.Clustering /// /// Internal class to create the op to initialize the clusters. /// - public class _InitializeClustersOpFactory + public class _InitializeClustersOpFactory : Python { Tensor[] _inputs; Tensor _num_clusters; - IInitializer _initial_clusters; + string _initial_clusters; string _distance_metric; int _random_seed; int _kmeans_plus_plus_num_retries; @@ -26,7 +26,7 @@ namespace Tensorflow.Clustering public _InitializeClustersOpFactory(Tensor[] inputs, Tensor num_clusters, - IInitializer initial_clusters, + string initial_clusters, string distance_metric, int random_seed, int kmeans_plus_plus_num_retries, @@ -56,7 +56,25 @@ namespace Tensorflow.Clustering { return control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0), () => new Operation[] { check_ops.assert_equal(_cluster_centers_initialized, true) }, - () => new Operation[0]); + _initialize); } + + private Operation[] _initialize() + { + with(ops.control_dependencies(new Operation[] + { + check_ops.assert_positive(_num_remaining) + }), delegate + { + // var num_now_remaining = _add_new_centers(); + }); + + throw new NotImplementedException("_InitializeClustersOpFactory _initialize"); + } + + /*private int _add_new_centers() + { + var new_centers = _choose_initial_centers(); + }*/ } } diff --git a/src/TensorFlowNET.Core/Operations/check_ops.cs b/src/TensorFlowNET.Core/Operations/check_ops.cs index bad058e9..169cb333 100644 --- a/src/TensorFlowNET.Core/Operations/check_ops.cs +++ b/src/TensorFlowNET.Core/Operations/check_ops.cs @@ -21,7 +21,58 @@ namespace Tensorflow var condition = math_ops.reduce_all(gen_math_ops.equal(x, y)); var x_static = tensor_util.constant_value(x); var y_static = tensor_util.constant_value(y); - return control_flow_ops.Asset(condition, data); + return control_flow_ops.Assert(condition, data); + }); + } + + public static Operation assert_positive(Tensor x, object[] data = null, string message = null, string name = null) + { + if (message == null) + message = ""; + + return with(ops.name_scope(name, "assert_positive", new { x, data }), delegate + { + x = ops.convert_to_tensor(x, name: "x"); + if (data == null) + { + name = x.name; + data = new object[] + { + message, + "Condition x > 0 did not hold element-wise:", + $"x (%s) = {name}", + x + }; + } + var zero = ops.convert_to_tensor(0, dtype: x.dtype); + return assert_less(zero, x, data: data); + }); + } + + public static Operation assert_less(Tensor x, Tensor y, object[] data = null, string message = null, string name = null) + { + if (message == null) + message = ""; + + return with(ops.name_scope(name, "assert_less", new { x, y, data }), delegate + { + x = ops.convert_to_tensor(x, name: "x"); + y = ops.convert_to_tensor(y, name: "y"); + string x_name = x.name; + string y_name = y.name; + if (data == null) + { + data = new object[] + { + message, + "Condition x < y did not hold element-wise:", + $"x (%s) = {x_name}", + $"y (%s) = {y_name}", + y + }; + } + var condition = math_ops.reduce_all(gen_math_ops.less(x, y)); + return control_flow_ops.Assert(condition, data); }); } } diff --git a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs index 95317afb..79416930 100644 --- a/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs +++ b/src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs @@ -9,7 +9,7 @@ namespace Tensorflow { public class control_flow_ops : Python { - public static Operation Asset(Tensor condition, object[] data, int? summarize = null, string name = null) + public static Operation Assert(Tensor condition, object[] data, int? summarize = null, string name = null) { return with(ops.name_scope(name, "Assert", new { condition, data }), scope => {