Browse Source

Add missing `operator >`s

Also unit testing all the operator cases.
pull/329/head
Antonio Cifonelli 6 years ago
parent
commit
358b183c5f
2 changed files with 158 additions and 0 deletions
  1. +4
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  2. +154
    -0
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 4
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -82,6 +82,10 @@ namespace Tensorflow

public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y);

public static Tensor operator >(double x, Tensor y) => gen_math_ops.greater(x, y);
public static Tensor operator >(float x, Tensor y) => gen_math_ops.greater(x, y);
public static Tensor operator >(int x, Tensor y) => gen_math_ops.greater(x, y);
public static Tensor operator >(Tensor x, Tensor y) => gen_math_ops.greater(x, y);
public static Tensor operator >(Tensor x, int y) => gen_math_ops.greater(x, y);
public static Tensor operator >=(Tensor x, Tensor y) => gen_math_ops.greater_equal(x, y);
public static Tensor operator >(Tensor x, float y) => gen_math_ops.greater(x, y);


+ 154
- 0
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -833,5 +833,159 @@ namespace TensorFlowNET.UnitTest
}
#endregion
}

[TestMethod]
public void greaterThanOpTests()
{
const int rows = 2; // to avoid broadcasting effect
const int cols = 10;

#region intTest
const int intThreshold = 10;

var firstIntFeed = Enumerable.Range(0, rows * cols).ToArray();
var secondIntFeed = Enumerable.Repeat(intThreshold, rows * cols).ToArray();
var intResult = firstIntFeed.Count(elem => elem > intThreshold);
var intResultTwo = firstIntFeed.Count(elem => elem < intThreshold);

var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
var c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater(a, b), tf.int32), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator >(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > b, tf.int32), 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator >(Tensor x, int y)
c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > intThreshold, tf.int32), 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator >(int x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(tf.cast(intThreshold > a, tf.int32), 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResultTwo);
}
#endregion

#region floatTest
const float floatThreshold = 10.0f;

var firstFloatFeed = Enumerable.Range(0, rows * cols).Select(elem => (float)elem).ToArray();
var secondFloatFeed = Enumerable.Repeat(floatThreshold, rows * cols).ToArray();
var floatResult = firstFloatFeed.Count(elem => elem > floatThreshold);
var floatResultTwo = firstFloatFeed.Count(elem => elem < floatThreshold);

a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater(a, b), tf.int32), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, floatResult);
}

// Testing `operator >(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > b, tf.int32), 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, floatResult);
}

// Testing `operator >(Tensor x, float y)
c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > floatThreshold, tf.int32), 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, floatResult);
}

// Testing `operator >(float x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(tf.cast(floatThreshold > a, tf.int32), 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, floatResultTwo);
}
#endregion

#region doubleTest
const double doubleThreshold = 10.0;

var firstDoubleFeed = Enumerable.Repeat(0, rows * cols).Select(elem => (double)elem).ToArray();
var secondDoubleFeed = Enumerable.Repeat(doubleThreshold, rows * cols).ToArray();
var doubleResult = firstDoubleFeed.Count(elem => elem > doubleThreshold);
var doubleResultTwo = firstDoubleFeed.Count(elem => elem < doubleThreshold);

a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
c = tf.reduce_sum(tf.reduce_sum(tf.cast(tf.greater(a, b), tf.int32), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, doubleResult);
}

// Testing `operator >(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > b, tf.int32), 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, doubleResult);
}

// Testing `operator >(Tensor x, double y)
c = tf.reduce_sum(tf.reduce_sum(tf.cast(a > doubleThreshold, tf.int32), 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, doubleResult);
}

// Testing `operator >(double x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(tf.cast(doubleThreshold > a, tf.int32), 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, doubleResultTwo);
}
#endregion
}
}
}

Loading…
Cancel
Save