Browse Source

Add missing `operator +`s

Also unit testing all the operator cases.
tags/v0.10
Antonio Cifonelli 6 years ago
parent
commit
e13c444bff
2 changed files with 160 additions and 0 deletions
  1. +5
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  2. +155
    -0
      test/TensorFlowNET.UnitTest/OperationsTest.cs

+ 5
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -23,8 +23,13 @@ namespace Tensorflow
{ {
public partial class Tensor public partial class Tensor
{ {
public static Tensor operator +(double x, Tensor y) => BinaryOpWrapper("add", x, y);
public static Tensor operator +(float x, Tensor y) => BinaryOpWrapper("add", x, y);
public static Tensor operator +(int x, Tensor y) => BinaryOpWrapper("add", x, y);
public static Tensor operator +(Tensor x, Tensor y) => BinaryOpWrapper("add", x, y); public static Tensor operator +(Tensor x, Tensor y) => BinaryOpWrapper("add", x, y);
public static Tensor operator +(Tensor x, int y) => BinaryOpWrapper("add", x, y); public static Tensor operator +(Tensor x, int y) => BinaryOpWrapper("add", x, y);
public static Tensor operator +(Tensor x, float y) => BinaryOpWrapper("add", x, y);
public static Tensor operator +(Tensor x, double y) => BinaryOpWrapper("add", x, y);


public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1); public static Tensor operator -(Tensor t1) => gen_math_ops.neg(t1);




+ 155
- 0
test/TensorFlowNET.UnitTest/OperationsTest.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using NumSharp;
using Tensorflow; using Tensorflow;
using Buffer = Tensorflow.Buffer; using Buffer = Tensorflow.Buffer;


@@ -60,5 +61,159 @@ namespace TensorFlowNET.UnitTest
Assert.AreEqual((float)o, 9.0f); Assert.AreEqual((float)o, 9.0f);
} }
} }

[TestMethod]
public void addOpTests()
{
const int rows = 2; // to avoid broadcasting effect
const int cols = 10;

#region intTest
const int firstIntVal = 2;
const int secondIntVal = 3;

var firstIntFeed = Enumerable.Repeat(firstIntVal, rows * cols).ToArray();
var secondIntFeed = Enumerable.Repeat(secondIntVal, rows * cols).ToArray();
var intResult = firstIntFeed.Sum() + secondIntFeed.Sum();

var a = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
var b = tf.placeholder(tf.int32, shape: new TensorShape(rows, cols));
var c = tf.reduce_sum(tf.reduce_sum(tf.add(a, b), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator +(Tensor x, Tensor y)`
c = tf.reduce_sum(tf.reduce_sum(a + b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator +(Tensor x, int y)`
c = tf.reduce_sum(tf.reduce_sum(a + secondIntVal, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}

// Testing `operator +(int x, Tensor y)`
c = tf.reduce_sum(tf.reduce_sum(secondIntVal + a, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstIntFeed, new Shape(rows, cols))));
Assert.AreEqual((int)o, intResult);
}
#endregion

#region floatTest
const float firstFloatVal = 2.0f;
const float secondFloatVal = 3.0f;

var firstFloatFeed = Enumerable.Repeat(firstFloatVal, rows * cols).ToArray();
var secondFloatFeed = Enumerable.Repeat(secondFloatVal, rows * cols).ToArray();
var floatResult = firstFloatFeed.Sum() + secondFloatFeed.Sum();

a = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
b = tf.placeholder(tf.float32, shape: new TensorShape(rows, cols));
c = tf.reduce_sum(tf.reduce_sum(tf.add(a, b), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}

// Testing `operator +(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(a + b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}

// Testing `operator +(Tensor x, float y)
c = tf.reduce_sum(tf.reduce_sum(a + secondFloatVal, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}

// Testing `operator +(float x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(secondFloatVal + a, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstFloatFeed, new Shape(rows, cols))));
Assert.AreEqual((float)o, floatResult);
}
#endregion

#region doubleTest
const double firstDoubleVal = 2.0;
const double secondDoubleVal = 3.0;

var firstDoubleFeed = Enumerable.Repeat(firstDoubleVal, rows * cols).ToArray();
var secondDoubleFeed = Enumerable.Repeat(secondDoubleVal, rows * cols).ToArray();
var doubleResult = firstDoubleFeed.Sum() + secondDoubleFeed.Sum();

a = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
b = tf.placeholder(tf.float64, shape: new TensorShape(rows, cols));
c = tf.reduce_sum(tf.reduce_sum(tf.add(a, b), 1));

using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}

// Testing `operator +(Tensor x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(a + b, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))),
new FeedItem(b, new NDArray(secondDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}

// Testing `operator +(Tensor x, double y)
c = tf.reduce_sum(tf.reduce_sum(a + secondFloatVal, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}

// Testing `operator +(double x, Tensor y)
c = tf.reduce_sum(tf.reduce_sum(secondFloatVal + a, 1));
using (var sess = tf.Session())
{
var o = sess.run(c,
new FeedItem(a, new NDArray(firstDoubleFeed, new Shape(rows, cols))));
Assert.AreEqual((double)o, doubleResult);
}
#endregion
}
} }
} }

Loading…
Cancel
Save