diff --git a/src/TensorFlowNET.Core/APIs/tf.debugging.cs b/src/TensorFlowNET.Core/APIs/tf.debugging.cs index 8e220594..579698a2 100644 --- a/src/TensorFlowNET.Core/APIs/tf.debugging.cs +++ b/src/TensorFlowNET.Core/APIs/tf.debugging.cs @@ -40,5 +40,15 @@ namespace Tensorflow message: message, name: name); + public Tensor assert_greater_equal(Tensor x, + Tensor y, + object[] data = null, + string message = null, + string name = null) + => check_ops.assert_greater_equal(x, + y, + data: data, + message: message, + name: name); } } diff --git a/src/TensorFlowNET.Core/Operations/check_ops.cs b/src/TensorFlowNET.Core/Operations/check_ops.cs index ef2ea3b6..6312aa9e 100644 --- a/src/TensorFlowNET.Core/Operations/check_ops.cs +++ b/src/TensorFlowNET.Core/Operations/check_ops.cs @@ -56,6 +56,37 @@ namespace Tensorflow return control_flow_ops.Assert(condition, data); }); } + + public static Operation assert_greater_equal(Tensor x, Tensor y, object[] data = null, string message = null, + string name = null) + { + if (message == null) + message = ""; + + return tf_with(ops.name_scope(name, "assert_greater_equal", new {x, y, data}), delegate + { + x = ops.convert_to_tensor(x, name: "x"); + y = ops.convert_to_tensor(y, name: "y"); + string x_name = x.name; + string y_name = y.name; + if (data == null) + { + data = new object[] + { + message, + "Condition x >= y did not hold element-wise:", + $"x (%s) = {x_name}", + x, + $"y (%s) = {y_name}", + y + }; + } + + var condition = math_ops.reduce_all(gen_math_ops.greater_equal(x, y)); + return control_flow_ops.Assert(condition, data); + }); + } + public static Operation assert_positive(Tensor x, object[] data = null, string message = null, string name = null) {