Browse Source

add check_ops.assert_positive, check_ops.assert_less

tags/v0.9
haiping008 6 years ago
parent
commit
128e56d8bb
4 changed files with 77 additions and 8 deletions
  1. +2
    -2
      src/TensorFlowNET.Core/Clustering/KMeans.cs
  2. +22
    -4
      src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs
  3. +52
    -1
      src/TensorFlowNET.Core/Operations/check_ops.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs

+ 2
- 2
src/TensorFlowNET.Core/Clustering/KMeans.cs View File

@@ -19,7 +19,7 @@ namespace Tensorflow.Clustering

Tensor[] _inputs;
int _num_clusters;
IInitializer _initial_clusters;
string _initial_clusters;
string _distance_metric;
bool _use_mini_batch;
int _mini_batch_steps_per_iteration;
@@ -29,7 +29,7 @@ namespace Tensorflow.Clustering

public KMeans(Tensor inputs,
int num_clusters,
IInitializer initial_clusters = null,
string initial_clusters = RANDOM_INIT,
string distance_metric = SQUARED_EUCLIDEAN_DISTANCE,
bool use_mini_batch = false,
int mini_batch_steps_per_iteration = 1,


+ 22
- 4
src/TensorFlowNET.Core/Clustering/_InitializeClustersOpFactory.cs View File

@@ -8,11 +8,11 @@ namespace Tensorflow.Clustering
/// <summary>
/// Internal class to create the op to initialize the clusters.
/// </summary>
public class _InitializeClustersOpFactory
public class _InitializeClustersOpFactory : Python
{
Tensor[] _inputs;
Tensor _num_clusters;
IInitializer _initial_clusters;
string _initial_clusters;
string _distance_metric;
int _random_seed;
int _kmeans_plus_plus_num_retries;
@@ -26,7 +26,7 @@ namespace Tensorflow.Clustering

public _InitializeClustersOpFactory(Tensor[] inputs,
Tensor num_clusters,
IInitializer initial_clusters,
string initial_clusters,
string distance_metric,
int random_seed,
int kmeans_plus_plus_num_retries,
@@ -56,7 +56,25 @@ namespace Tensorflow.Clustering
{
return control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0),
() => new Operation[] { check_ops.assert_equal(_cluster_centers_initialized, true) },
() => new Operation[0]);
_initialize);
}

private Operation[] _initialize()
{
with(ops.control_dependencies(new Operation[]
{
check_ops.assert_positive(_num_remaining)
}), delegate
{
// var num_now_remaining = _add_new_centers();
});

throw new NotImplementedException("_InitializeClustersOpFactory _initialize");
}

/*private int _add_new_centers()
{
var new_centers = _choose_initial_centers();
}*/
}
}

+ 52
- 1
src/TensorFlowNET.Core/Operations/check_ops.cs View File

@@ -21,7 +21,58 @@ namespace Tensorflow
var condition = math_ops.reduce_all(gen_math_ops.equal(x, y));
var x_static = tensor_util.constant_value(x);
var y_static = tensor_util.constant_value(y);
return control_flow_ops.Asset(condition, data);
return control_flow_ops.Assert(condition, data);
});
}

public static Operation assert_positive(Tensor x, object[] data = null, string message = null, string name = null)
{
if (message == null)
message = "";

return with(ops.name_scope(name, "assert_positive", new { x, data }), delegate
{
x = ops.convert_to_tensor(x, name: "x");
if (data == null)
{
name = x.name;
data = new object[]
{
message,
"Condition x > 0 did not hold element-wise:",
$"x (%s) = {name}",
x
};
}
var zero = ops.convert_to_tensor(0, dtype: x.dtype);
return assert_less(zero, x, data: data);
});
}

public static Operation assert_less(Tensor x, Tensor y, object[] data = null, string message = null, string name = null)
{
if (message == null)
message = "";

return with(ops.name_scope(name, "assert_less", new { x, y, data }), delegate
{
x = ops.convert_to_tensor(x, name: "x");
y = ops.convert_to_tensor(y, name: "y");
string x_name = x.name;
string y_name = y.name;
if (data == null)
{
data = new object[]
{
message,
"Condition x < y did not hold element-wise:",
$"x (%s) = {x_name}",
$"y (%s) = {y_name}",
y
};
}
var condition = math_ops.reduce_all(gen_math_ops.less(x, y));
return control_flow_ops.Assert(condition, data);
});
}
}


+ 1
- 1
src/TensorFlowNET.Core/Operations/control_flow_ops.py.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow
{
public class control_flow_ops : Python
{
public static Operation Asset(Tensor condition, object[] data, int? summarize = null, string name = null)
public static Operation Assert(Tensor condition, object[] data, int? summarize = null, string name = null)
{
return with(ops.name_scope(name, "Assert", new { condition, data }), scope =>
{


Loading…
Cancel
Save