diff --git a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs index a10a099a..11a89af1 100644 --- a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs +++ b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Text; +using static Tensorflow.OpDef.Types; namespace Tensorflow { @@ -19,9 +20,32 @@ namespace Tensorflow foreach (var op_def in op_list.Op) _registered_ops[op_def.Name] = op_def; + + if (!_registered_ops.ContainsKey("NearestNeighbors")) + _registered_ops["NearestNeighbors"] = op_NearestNeighbors(); } return _registered_ops; } + + /// + /// Doesn't work because the op can't be found on binary + /// + /// + private static OpDef op_NearestNeighbors() + { + var def = new OpDef + { + Name = "NearestNeighbors" + }; + + def.InputArg.Add(new ArgDef { Name = "points", Type = DataType.DtFloat }); + def.InputArg.Add(new ArgDef { Name = "centers", Type = DataType.DtFloat }); + def.InputArg.Add(new ArgDef { Name = "k", Type = DataType.DtInt64 }); + def.OutputArg.Add(new ArgDef { Name = "nearest_center_indices", Type = DataType.DtInt64 }); + def.OutputArg.Add(new ArgDef { Name = "nearest_center_distances", Type = DataType.DtFloat }); + + return def; + } } } diff --git a/test/TensorFlowNET.Examples/KMeansClustering.cs b/test/TensorFlowNET.Examples/KMeansClustering.cs index b09e38fc..d18a1153 100644 --- a/test/TensorFlowNET.Examples/KMeansClustering.cs +++ b/test/TensorFlowNET.Examples/KMeansClustering.cs @@ -33,6 +33,8 @@ namespace TensorFlowNET.Examples public bool Run() { + tf.train.import_meta_graph("kmeans.meta"); + // Input images var X = tf.placeholder(tf.float32, shape: new TensorShape(-1, num_features)); // Labels (for assigning a label to a centroid and testing)