From 1c30b2b8cbb987f17d29270948f28f21427bcb80 Mon Sep 17 00:00:00 2001 From: Antonio Cifonelli Date: Fri, 19 Jul 2019 17:16:07 +0200 Subject: [PATCH] Add missing `operator >=`s Also unit testing the new operators. --- .../Tensors/Tensor.Operators.cs | 9 +- test/TensorFlowNET.UnitTest/OperationsTest.cs | 154 ++++++++++++++++++ 2 files changed, 162 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs index a5a9b674..817f14c6 100644 --- a/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs +++ b/src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs @@ -87,7 +87,6 @@ namespace Tensorflow public static Tensor operator >(int x, Tensor y) => gen_math_ops.greater(x, y); public static Tensor operator >(Tensor x, Tensor y) => gen_math_ops.greater(x, y); public static Tensor operator >(Tensor x, int y) => gen_math_ops.greater(x, y); - public static Tensor operator >=(Tensor x, Tensor y) => gen_math_ops.greater_equal(x, y); public static Tensor operator >(Tensor x, float y) => gen_math_ops.greater(x, y); public static Tensor operator >(Tensor x, double y) => gen_math_ops.greater(x, y); @@ -100,6 +99,14 @@ namespace Tensorflow public static Tensor operator <(Tensor x, float y) => gen_math_ops.less(x, y); public static Tensor operator <(Tensor x, double y) => gen_math_ops.less(x, y); + public static Tensor operator >=(double x, Tensor y) => gen_math_ops.greater_equal(x, y); + public static Tensor operator >=(float x, Tensor y) => gen_math_ops.greater_equal(x, y); + public static Tensor operator >=(int x, Tensor y) => gen_math_ops.greater_equal(x, y); + public static Tensor operator >=(Tensor x, Tensor y) => gen_math_ops.greater_equal(x, y); + public static Tensor operator >=(Tensor x, int y) => gen_math_ops.greater_equal(x, y); + public static Tensor operator >=(Tensor x, float y) => gen_math_ops.greater_equal(x, y); + public static Tensor operator >=(Tensor x, double y) => gen_math_ops.greater_equal(x, y); + private static Tensor BinaryOpWrapper(string name, Tx x, Ty y) { TF_DataType dtype = TF_DataType.DtInvalid; diff --git a/test/TensorFlowNET.UnitTest/OperationsTest.cs b/test/TensorFlowNET.UnitTest/OperationsTest.cs index 358f3fb9..56927df5 100644 --- a/test/TensorFlowNET.UnitTest/OperationsTest.cs +++ b/test/TensorFlowNET.UnitTest/OperationsTest.cs @@ -1141,5 +1141,159 @@ namespace TensorFlowNET.UnitTest } #endregion } + + [TestMethod] + public void greaterOrEqualThanOpTests() + { + const int rows = 2; // to avoid broadcasting effect + const int cols = 10; + + #region intTest + const int intThreshold = 10; + + var firstIntFeed = Enumerable.Range(0, rows * cols).ToArray(); + var secondIntFeed = Enumerable.Repeat(intThreshold, rows * cols).ToArray(); + var intResult = firstIntFeed.Count(elem => elem >= intThreshold); + var intResultTwo = firstIntFeed.Count(elem => elem <= intThreshold); + + var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols)); + var c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater_equal(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator >=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator >=(Tensor x, int y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= intThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResult); + } + + // Testing `operator >=(int x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(intThreshold >= a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, intResultTwo); + } + #endregion + + #region floatTest + const float floatThreshold = 10.0f; + + var firstFloatFeed = Enumerable.Range(0, rows * cols).Select(elem => (float)elem).ToArray(); + var secondFloatFeed = Enumerable.Repeat(floatThreshold, rows * cols).ToArray(); + var floatResult = firstFloatFeed.Count(elem => elem >= floatThreshold); + var floatResultTwo = firstFloatFeed.Count(elem => elem <= floatThreshold); + + a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater_equal(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator >=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator >=(Tensor x, float y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= floatThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResult); + } + + // Testing `operator >=(float x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(floatThreshold >= a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, floatResultTwo); + } + #endregion + + #region doubleTest + const double doubleThreshold = 10.0; + + var firstDoubleFeed = Enumerable.Repeat(0, rows * cols).Select(elem => (double)elem).ToArray(); + var secondDoubleFeed = Enumerable.Repeat(doubleThreshold, rows * cols).ToArray(); + var doubleResult = firstDoubleFeed.Count(elem => elem >= doubleThreshold); + var doubleResultTwo = firstDoubleFeed.Count(elem => elem <= doubleThreshold); + + a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols)); + c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater_equal(a, b), tf.int32), 1)); + + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator >=(Tensor x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= b, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))), + new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator >=(Tensor x, double y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(a >= doubleThreshold, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResult); + } + + // Testing `operator >=(double x, Tensor y) + c = tf.reduce_sum(tf.reduce_sum(tf.cast(doubleThreshold >= a, tf.int32), 1)); + using (var sess = tf.Session()) + { + var o = sess.run(c, + new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols)))); + Assert.AreEqual((int)o, doubleResultTwo); + } + #endregion + } } }