Browse Source

Fixed unit tests

tags/v0.12
Eli Belash 6 years ago
parent
commit
605a05eef5
5 changed files with 139 additions and 138 deletions
  1. +17
    -17
      test/TensorFlowNET.UnitTest/ConstantTest.cs
  2. +3
    -3
      test/TensorFlowNET.UnitTest/GradientTest.cs
  3. +111
    -111
      test/TensorFlowNET.UnitTest/OperationsTest.cs
  4. +1
    -1
      test/TensorFlowNET.UnitTest/PlaceholderTest.cs
  5. +7
    -6
      test/TensorFlowNET.UnitTest/VariableTest.cs

+ 17
- 17
test/TensorFlowNET.UnitTest/ConstantTest.cs View File

@@ -98,9 +98,9 @@ namespace TensorFlowNET.UnitTest
{ {
var result = sess.run(tensor); var result = sess.run(tensor);


Assert.AreEqual(result[0].shape[0], 3);
Assert.AreEqual(result[0].shape[1], 2);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result[0].Data<int>()));
Assert.AreEqual(result.shape[0], 3);
Assert.AreEqual(result.shape[1], 2);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result.Data<int>()));
} }


// big size // big size
@@ -109,13 +109,13 @@ namespace TensorFlowNET.UnitTest
{ {
var result = sess.run(tensor); var result = sess.run(tensor);


Assert.AreEqual(result[0].shape[0], 200);
Assert.AreEqual(result[0].shape[1], 100);
Assert.AreEqual(result.shape[0], 200);
Assert.AreEqual(result.shape[1], 100);


var data = result[0].Data<int>();
var data = result.Data<int>();
Assert.AreEqual(0, data[0]); Assert.AreEqual(0, data[0]);
Assert.AreEqual(0, data[500]); Assert.AreEqual(0, data[500]);
Assert.AreEqual(0, data[result[0].size - 1]);
Assert.AreEqual(0, data[result.size - 1]);
} }
} }


@@ -127,9 +127,9 @@ namespace TensorFlowNET.UnitTest
{ {
var result = sess.run(ones); var result = sess.run(ones);


Assert.AreEqual(result[0].shape[0], 3);
Assert.AreEqual(result[0].shape[1], 2);
Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result[0].Data<int>()));
Assert.AreEqual(result.shape[0], 3);
Assert.AreEqual(result.shape[1], 2);
Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result.Data<int>()));
} }
} }


@@ -142,9 +142,9 @@ namespace TensorFlowNET.UnitTest
{ {
var result = sess.run(halfes); var result = sess.run(halfes);


Assert.AreEqual(result[0].shape[0], 3);
Assert.AreEqual(result[0].shape[1], 2);
Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result[0].Data<double>()));
Assert.AreEqual(result.shape[0], 3);
Assert.AreEqual(result.shape[1], 2);
Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result.Data<double>()));
} }
} }


@@ -161,10 +161,10 @@ namespace TensorFlowNET.UnitTest
using (var sess = tf.Session()) using (var sess = tf.Session())
{ {
var result = sess.run(tensor); var result = sess.run(tensor);
var data = result[0].Data<int>();
var data = result.Data<int>();


Assert.AreEqual(result[0].shape[0], 2);
Assert.AreEqual(result[0].shape[1], 3);
Assert.AreEqual(result.shape[0], 2);
Assert.AreEqual(result.shape[1], 3);
Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data)); Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data));
} }
} }
@@ -177,7 +177,7 @@ namespace TensorFlowNET.UnitTest
var c = a * b; var c = a * b;


var sess = tf.Session(); var sess = tf.Session();
double result = sess.run(c)[0];
double result = sess.run(c);
sess.close(); sess.close();


Assert.AreEqual(6.0, result); Assert.AreEqual(6.0, result);


+ 3
- 3
test/TensorFlowNET.UnitTest/GradientTest.cs View File

@@ -41,7 +41,7 @@ namespace TensorFlowNET.UnitTest
var grad = tf.gradients(y, x); var grad = tf.gradients(y, x);
Assert.AreEqual(grad[0].name, "gradients/AddN:0"); Assert.AreEqual(grad[0].name, "gradients/AddN:0");


float r = sess.run(grad[0])[0];
float r = sess.run(grad[0]);
Assert.AreEqual(r, 1.4f); Assert.AreEqual(r, 1.4f);
} }
} }
@@ -57,7 +57,7 @@ namespace TensorFlowNET.UnitTest
var grad = tf.gradients(y, x); var grad = tf.gradients(y, x);
Assert.AreEqual(grad[0].name, "gradients/AddN:0"); Assert.AreEqual(grad[0].name, "gradients/AddN:0");


float r = sess.run(grad[0])[0];
float r = sess.run(grad[0]);
Assert.AreEqual(r, 14.700001f); Assert.AreEqual(r, 14.700001f);
}); });
} }
@@ -94,7 +94,7 @@ namespace TensorFlowNET.UnitTest


using (var sess = tf.Session(graph)) using (var sess = tf.Session(graph))
{ {
var r = sess.run(slice)[0];
var r = sess.run(slice);


Assert.IsTrue(Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 })); Assert.IsTrue(Enumerable.SequenceEqual(r.shape, new[] { 2, 1, 2 }));
Assert.IsTrue(Enumerable.SequenceEqual(r[0].GetData<int>(), new[] { 11, 13 })); Assert.IsTrue(Enumerable.SequenceEqual(r[0].GetData<int>(), new[] { 11, 13 }));


+ 111
- 111
test/TensorFlowNET.UnitTest/OperationsTest.cs
File diff suppressed because it is too large
View File


+ 1
- 1
test/TensorFlowNET.UnitTest/PlaceholderTest.cs View File

@@ -17,7 +17,7 @@ namespace TensorFlowNET.UnitTest
{ {
var result = sess.run(y, var result = sess.run(y,
new FeedItem(x, 2)); new FeedItem(x, 2));
Assert.AreEqual((int)result[0], 6);
Assert.AreEqual((int)result, 6);
} }
} }
} }


+ 7
- 6
test/TensorFlowNET.UnitTest/VariableTest.cs View File

@@ -1,4 +1,5 @@
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using NumSharp;
using Tensorflow; using Tensorflow;
using static Tensorflow.Binding; using static Tensorflow.Binding;


@@ -16,7 +17,7 @@ namespace TensorFlowNET.UnitTest
{ {
session.run(x.initializer); session.run(x.initializer);
var result = session.run(x); var result = session.run(x);
Assert.AreEqual(10, (int)result[0]);
Assert.AreEqual(10, (int)result);
} }
} }


@@ -81,7 +82,7 @@ namespace TensorFlowNET.UnitTest
using (var session = tf.Session()) using (var session = tf.Session())
{ {
session.run(model); session.run(model);
int result = session.run(y)[0];
int result = session.run(y);
Assert.AreEqual(result, 4); Assert.AreEqual(result, 4);
} }
} }
@@ -97,12 +98,12 @@ namespace TensorFlowNET.UnitTest
var sess = tf.Session(graph); var sess = tf.Session(graph);
sess.run(init); sess.run(init);


var result = sess.run(variable);
Assert.IsTrue((int)result[0] == 31);
NDArray result = sess.run(variable);
Assert.IsTrue((int)result == 31);


var assign = variable.assign(12); var assign = variable.assign(12);
result = sess.run(assign); result = sess.run(assign);
Assert.IsTrue((int)result[0] == 12);
Assert.IsTrue((int)result == 12);
} }


[TestMethod] [TestMethod]
@@ -139,7 +140,7 @@ namespace TensorFlowNET.UnitTest
for(int i = 0; i < 5; i++) for(int i = 0; i < 5; i++)
{ {
x = x + 1; x = x + 1;
result = session.run(x)[0];
result = session.run(x);
print(result); print(result);
} }
} }


Loading…
Cancel
Save