| @@ -52,13 +52,6 @@ namespace Tensorflow.Clustering | |||||
| _num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i)[0]).ToArray()); | _num_data = math_ops.add_n(_inputs.Select(i => array_ops.shape(i)[0]).ToArray()); | ||||
| } | } | ||||
| public Tensor[] op() | |||||
| { | |||||
| return control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0), | |||||
| () => new Operation[] { check_ops.assert_equal(_cluster_centers_initialized, true) }, | |||||
| _initialize); | |||||
| } | |||||
| private Operation[] _initialize() | private Operation[] _initialize() | ||||
| { | { | ||||
| with(ops.control_dependencies(new Operation[] | with(ops.control_dependencies(new Operation[] | ||||
| @@ -72,6 +65,17 @@ namespace Tensorflow.Clustering | |||||
| throw new NotImplementedException("_InitializeClustersOpFactory _initialize"); | throw new NotImplementedException("_InitializeClustersOpFactory _initialize"); | ||||
| } | } | ||||
| public Tensor[] op() | |||||
| { | |||||
| return control_flow_ops.cond(gen_math_ops.equal(_num_remaining, 0), | |||||
| () => | |||||
| { | |||||
| var op = check_ops.assert_equal(_cluster_centers_initialized, true); | |||||
| return new Operation[] { op }; | |||||
| }, | |||||
| _initialize); | |||||
| } | |||||
| /*private int _add_new_centers() | /*private int _add_new_centers() | ||||
| { | { | ||||
| var new_centers = _choose_initial_centers(); | var new_centers = _choose_initial_centers(); | ||||
| @@ -25,7 +25,7 @@ namespace Tensorflow | |||||
| public void _add_control_input(Operation op) | public void _add_control_input(Operation op) | ||||
| { | { | ||||
| c_api.TF_AddControlInput(_handle, op); | |||||
| c_api.TF_AddControlInput(_operDesc, op); | |||||
| } | } | ||||
| public void _add_control_inputs(Operation[] ops) | public void _add_control_inputs(Operation[] ops) | ||||
| @@ -11,6 +11,7 @@ namespace Tensorflow | |||||
| public partial class Operation : ITensorOrOperation | public partial class Operation : ITensorOrOperation | ||||
| { | { | ||||
| private readonly IntPtr _handle; // _c_op in python | private readonly IntPtr _handle; // _c_op in python | ||||
| private readonly IntPtr _operDesc; | |||||
| private Graph _graph; | private Graph _graph; | ||||
| //[JsonIgnore] | //[JsonIgnore] | ||||
| @@ -58,9 +59,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| _graph = g; | _graph = g; | ||||
| var desc = c_api.TF_NewOperation(g, opType, oper_name); | |||||
| c_api.TF_SetAttrType(desc, "dtype", TF_DataType.TF_INT32); | |||||
| c_api.TF_FinishOperation(desc, status); | |||||
| _operDesc = c_api.TF_NewOperation(g, opType, oper_name); | |||||
| c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | |||||
| _handle = c_api.TF_FinishOperation(_operDesc, status); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -112,7 +113,7 @@ namespace Tensorflow | |||||
| op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
| var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | ||||
| _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
| (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
| // Initialize self._outputs. | // Initialize self._outputs. | ||||
| output_types = new TF_DataType[NumOutputs]; | output_types = new TF_DataType[NumOutputs]; | ||||
| @@ -12,12 +12,29 @@ namespace Tensorflow | |||||
| /// <param name="t1"></param> | /// <param name="t1"></param> | ||||
| /// <param name="t2"></param> | /// <param name="t2"></param> | ||||
| /// <param name="name"></param> | /// <param name="name"></param> | ||||
| public static Operation assert_equal(object t1, object t2, object[] data = null, string name = null) | |||||
| public static Operation assert_equal(object t1, object t2, object[] data = null, string message = null, string name = null) | |||||
| { | { | ||||
| if (message == null) | |||||
| message = ""; | |||||
| return with(ops.name_scope(name, "assert_equal", new { t1, t2, data }), delegate | return with(ops.name_scope(name, "assert_equal", new { t1, t2, data }), delegate | ||||
| { | { | ||||
| var x = ops.convert_to_tensor(t1, name: "x"); | var x = ops.convert_to_tensor(t1, name: "x"); | ||||
| var y = ops.convert_to_tensor(t2, name: "y"); | var y = ops.convert_to_tensor(t2, name: "y"); | ||||
| if (data == null) | |||||
| { | |||||
| data = new object[] | |||||
| { | |||||
| message, | |||||
| "Condition x == y did not hold element-wise:", | |||||
| $"x (%s) = {x.name}", | |||||
| x, | |||||
| $"y (%s) = {y.name}", | |||||
| y | |||||
| }; | |||||
| } | |||||
| var condition = math_ops.reduce_all(gen_math_ops.equal(x, y)); | var condition = math_ops.reduce_all(gen_math_ops.equal(x, y)); | ||||
| var x_static = tensor_util.constant_value(x); | var x_static = tensor_util.constant_value(x); | ||||
| var y_static = tensor_util.constant_value(y); | var y_static = tensor_util.constant_value(y); | ||||
| @@ -265,7 +265,7 @@ namespace Tensorflow | |||||
| public override string ToString() | public override string ToString() | ||||
| { | { | ||||
| return $"tf.Variable '{name}' shape={shape} dtype={dtype}"; | |||||
| return $"tf.RefVariable '{name}' shape={shape} dtype={dtype}"; | |||||
| } | } | ||||
| public VariableDef to_proto(string export_scope) | public VariableDef to_proto(string export_scope) | ||||
| @@ -122,7 +122,7 @@ namespace Tensorflow | |||||
| /// </param> | /// </param> | ||||
| /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | ||||
| /// <returns>A wrapped TF_Operation*.</returns> | /// <returns>A wrapped TF_Operation*.</returns> | ||||
| public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | |||||
| public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | |||||
| { | { | ||||
| var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | var op_desc = graph.NewOperation(node_def.Op, node_def.Name); | ||||
| @@ -164,7 +164,7 @@ namespace Tensorflow | |||||
| status.Check(true); | status.Check(true); | ||||
| return c_op; | |||||
| return (c_op, op_desc); | |||||
| } | } | ||||
| public static OpDef _get_op_def(Graph graph, string type) | public static OpDef _get_op_def(Graph graph, string type) | ||||
| @@ -16,7 +16,7 @@ namespace TensorFlowNET.Examples | |||||
| public class KMeansClustering : Python, IExample | public class KMeansClustering : Python, IExample | ||||
| { | { | ||||
| public int Priority => 8; | public int Priority => 8; | ||||
| public bool Enabled => false; | |||||
| public bool Enabled => true; | |||||
| public string Name => "K-means Clustering"; | public string Name => "K-means Clustering"; | ||||
| Datasets mnist; | Datasets mnist; | ||||