Browse Source

cond: added more testcases

tags/v0.9
Meinrad Recheis 6 years ago
parent
commit
82af34f411
6 changed files with 1413 additions and 3 deletions
  1. +8
    -1
      src/TensorFlowNET.Core/Operations/control_flow_util.py.cs
  2. +54
    -2
      test/TensorFlowNET.UnitTest/PythonTest.cs
  3. +107
    -0
      test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs
  4. +23
    -0
      test/TensorFlowNET.UnitTest/control_flow_ops_test/ShapeTestCase.cs
  5. +162
    -0
      test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs
  6. +1059
    -0
      test/TensorFlowNET.UnitTest/control_flow_ops_test/control_flow_ops_test.py

+ 8
- 1
src/TensorFlowNET.Core/Operations/control_flow_util.py.cs View File

@@ -27,10 +27,17 @@ namespace Tensorflow
return op.type == "Switch" || op.type == "RefSwitch"; return op.type == "Switch" || op.type == "RefSwitch";
} }


/// <summary>
/// Return the control flow context for the output of an op.
/// </summary>
public static IControlFlowContext GetOutputContext(Operation op) public static IControlFlowContext GetOutputContext(Operation op)
{ {
var ctxt = op._get_control_flow_context(); var ctxt = op._get_control_flow_context();

// Exit nodes usually have a control flow context, except in the case where the
// exit node was imported via import_graph_def (in which case no nodes have
// control flow contexts).
if (ctxt != null && IsLoopExit(op))
ctxt = ctxt.outer_context;
return ctxt; return ctxt;
} }
} }


+ 54
- 2
test/TensorFlowNET.UnitTest/PythonTest.cs View File

@@ -5,6 +5,7 @@ using System.Linq;
using System.Text; using System.Text;
using Microsoft.VisualStudio.TestTools.UnitTesting; using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow; using Tensorflow;
using Tensorflow.Util;
namespace TensorFlowNET.UnitTest namespace TensorFlowNET.UnitTest
{ {
@@ -13,6 +14,15 @@ namespace TensorFlowNET.UnitTest
/// </summary> /// </summary>
public class PythonTest : Python public class PythonTest : Python
{ {
#region python compatibility layer
protected PythonTest self { get => this; }
protected object None {
get { return null; }
}
#endregion
#region pytest assertions
public void assertItemsEqual(ICollection given, ICollection expected) public void assertItemsEqual(ICollection given, ICollection expected)
{ {
Assert.IsNotNull(expected); Assert.IsNotNull(expected);
@@ -20,20 +30,62 @@ namespace TensorFlowNET.UnitTest
var e = expected.OfType<object>().ToArray(); var e = expected.OfType<object>().ToArray();
var g = given.OfType<object>().ToArray(); var g = given.OfType<object>().ToArray();
Assert.AreEqual(e.Length, g.Length, $"The collections differ in length expected {e.Length} but got {g.Length}"); Assert.AreEqual(e.Length, g.Length, $"The collections differ in length expected {e.Length} but got {g.Length}");
for(int i=0; i<e.Length; i++)
for (int i = 0; i < e.Length; i++)
Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}"); Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}");
} }
public void assertEqual(object given, object expected) public void assertEqual(object given, object expected)
{ {
if (given is ICollection && expected is ICollection)
{
assertItemsEqual(given as ICollection, expected as ICollection);
return;
}
Assert.AreEqual(expected, given); Assert.AreEqual(expected, given);
} }
public void assertEquals(object given, object expected)
{
assertEqual(given, expected);
}
public void assertIsNotNone(object given) public void assertIsNotNone(object given)
{ {
Assert.IsNotNull(given); Assert.IsNotNull(given);
} }
protected PythonTest self { get => this; }
#endregion
#region tensor evaluation
protected object _eval_helper(Tensor[] tensors)
{
if (tensors == null)
return null;
//return nest.map_structure(self._eval_tensor, tensors);
return null;
}
//def evaluate(self, tensors) :
// """Evaluates tensors and returns numpy values.
// Args:
// tensors: A Tensor or a nested list/tuple of Tensors.
// Returns:
// tensors numpy values.
// """
// if context.executing_eagerly():
// return self._eval_helper(tensors)
// else:
// sess = ops.get_default_session()
// if sess is None:
// with self.test_session() as sess:
// return sess.run(tensors)
// else:
// return sess.run(tensors)
#endregion
} }
} }

+ 107
- 0
test/TensorFlowNET.UnitTest/control_flow_ops_test/CondTestCases.cs View File

@@ -0,0 +1,107 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
namespace TensorFlowNET.UnitTest.control_flow_ops_test
{
/// <summary>
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py
/// </summary>
[TestClass]
public class CondTestCases : PythonTest
{
[Ignore("Todo")]
[TestMethod]
public void testCondTrue()
{
//var x = constant_op.constant(2);
//var y = constant_op.constant(5);
// var z = control_flow_ops.cond(math_ops.less(x,y), ()=> math_ops.multiply(x, 17), ()=> math_ops.add(y, 23))
//self.assertEquals(self.evaluate(z), 34);
}
[Ignore("Todo")]
[TestMethod]
public void testCondFalse()
{
// def testCondFalse(self):
// x = constant_op.constant(2)
// y = constant_op.constant(1)
// z = control_flow_ops.cond(
// math_ops.less(
// x,
// y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23))
// self.assertEquals(self.evaluate(z), 24)
}
[Ignore("Todo")]
[TestMethod]
public void testCondTrueLegacy()
{
// def testCondTrueLegacy(self):
// x = constant_op.constant(2)
// y = constant_op.constant(5)
// z = control_flow_ops.cond(
// math_ops.less(x, y),
// fn1=lambda: math_ops.multiply(x, 17),
// fn2=lambda: math_ops.add(y, 23))
// self.assertEquals(self.evaluate(z), 34)
}
[Ignore("Todo")]
[TestMethod]
public void testCondFalseLegacy()
{
// def testCondFalseLegacy(self):
// x = constant_op.constant(2)
// y = constant_op.constant(1)
// z = control_flow_ops.cond(
// math_ops.less(x, y),
// fn1=lambda: math_ops.multiply(x, 17),
// fn2=lambda: math_ops.add(y, 23))
// self.assertEquals(self.evaluate(z), 24)
}
[Ignore("Todo")]
[TestMethod]
public void testCondMissingArg1()
{
// def testCondMissingArg1(self):
// x = constant_op.constant(1)
// with self.assertRaises(TypeError):
// control_flow_ops.cond(True, false_fn=lambda: x)
}
[Ignore("Todo")]
[TestMethod]
public void testCondMissingArg2()
{
// def testCondMissingArg2(self):
// x = constant_op.constant(1)
// with self.assertRaises(TypeError):
// control_flow_ops.cond(True, lambda: x)
}
[Ignore("Todo")]
[TestMethod]
public void testCondDuplicateArg1()
{
// def testCondDuplicateArg1(self):
// x = constant_op.constant(1)
// with self.assertRaises(TypeError):
// control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
}
[Ignore("Todo")]
[TestMethod]
public void testCondDuplicateArg2()
{
// def testCondDuplicateArg2(self):
// x = constant_op.constant(1)
// with self.assertRaises(TypeError):
// control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
}
}
}

+ 23
- 0
test/TensorFlowNET.UnitTest/control_flow_ops_test/ShapeTestCase.cs View File

@@ -0,0 +1,23 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
namespace TensorFlowNET.UnitTest.control_flow_ops_test
{
/// <summary>
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py
/// </summary>
[TestClass]
public class ShapeTestCase : PythonTest
{
[TestMethod]
public void testShape()
{
var tensor = constant_op.constant(new[]{1.0, 2.0});
self.assertEquals(new int[] {2}, tensor.shape);
self.assertEquals(new int[] {2},
control_flow_ops.with_dependencies(new[] {constant_op.constant(1.0).op}, tensor).shape);
}
}
}

+ 162
- 0
test/TensorFlowNET.UnitTest/control_flow_ops_test/SwitchTestCase.cs View File

@@ -0,0 +1,162 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow;
namespace TensorFlowNET.UnitTest.control_flow_ops_test
{
/// <summary>
/// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py
/// </summary>
[TestClass]
public class SwitchTestCase : PythonTest
{
[Ignore("TODO")]
[TestMethod]
public void testResourceReadInLoop()
{
//def testResourceReadInLoop(self):
// embedding_matrix = variable_scope.get_variable(
// "embedding_matrix", initializer=[[2.0], [3.0]], use_resource=True)
//
// def cond(it, _):
// return it < 5
//
// def body(it, cost):
// embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
// cost += math_ops.reduce_sum(embedding)
// return it + 1, cost
//
// _, cost = control_flow_ops.while_loop(
// cond, body, [constant_op.constant(0),
// constant_op.constant(0.0)])
// with self.cached_session():
// self.evaluate(variables.global_variables_initializer())
// self.assertAllEqual(10.0, self.evaluate(cost))
}
[Ignore("TODO")]
[TestMethod]
public void testIndexedSlicesGradientInCondInWhileLoop()
{
doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: false);
}
[Ignore("TODO")]
[TestMethod]
public void testIndexedSlicesGradientInCondInWhileLoopResource()
{
doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: true);
}
private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource= false)
{
//def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False):
// embedding_matrix = variable_scope.get_variable(
// "embedding_matrix", [5, 5],
// initializer=init_ops.random_normal_initializer(),
// use_resource=use_resource)
// def cond(it, _):
// return it < 5
// def body(it, cost):
// embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
// cost = control_flow_ops.cond(
// math_ops.equal(it, 3), lambda: math_ops.square(cost),
// (lambda: cost + math_ops.reduce_sum(embedding)))
// return it + 1, cost
// _, cost = control_flow_ops.while_loop(
// cond, body, [constant_op.constant(0),
// constant_op.constant(0.0)])
// dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0]
// dynamic_grads = math_ops.segment_sum(dynamic_grads.values,
// dynamic_grads.indices)
// embedding = embedding_ops.embedding_lookup(embedding_matrix, [0])
// static = math_ops.square(
// math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) +
// math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding)
// static_grads = gradients_impl.gradients(static, [embedding_matrix])[0]
// static_grads = math_ops.segment_sum(static_grads.values,
// static_grads.indices)
// with self.cached_session():
// self.evaluate(variables.global_variables_initializer())
// self.assertAllEqual(*self.evaluate([static_grads, dynamic_grads]))
}
[Ignore("TODO")]
[TestMethod]
public void testIndexedSlicesWithShapeGradientInWhileLoop()
{
//@test_util.run_v1_only("b/120545219")
//def testIndexedSlicesWithShapeGradientInWhileLoop(self):
// for dtype in [dtypes.float32, dtypes.float64]:
// with self.cached_session() as sess:
// num_steps = 9
// inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps])
// initial_outputs = tensor_array_ops.TensorArray(
// dtype=dtype, size=num_steps)
// initial_i = constant_op.constant(0, dtype=dtypes.int32)
// def cond(i, _):
// return i < num_steps # pylint: disable=cell-var-from-loop
// def body(i, outputs):
// x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop
// outputs = outputs.write(i, x)
// return i + 1, outputs
// _, outputs = control_flow_ops.while_loop(cond, body,
// [initial_i, initial_outputs])
// outputs = math_ops.reduce_sum(outputs.stack())
// r = gradients_impl.gradients([outputs], [inputs])[0]
// grad_wr_inputs = ops.convert_to_tensor(r)
// o, grad = sess.run([outputs, grad_wr_inputs],
// feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]})
// self.assertEquals(o, 20)
// self.assertAllEqual(grad, [1] * num_steps)
}
[Ignore("TODO")]
[TestMethod]
public void testIndexedSlicesWithDynamicShapeGradientInWhileLoop()
{
//@test_util.run_v1_only("b/120545219")
//def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self):
// for dtype in [dtypes.float32, dtypes.float64]:
// with self.cached_session() as sess:
// inputs = array_ops.placeholder(dtype=dtype)
// initial_outputs = tensor_array_ops.TensorArray(
// dtype=dtype, dynamic_size=True, size=1)
// initial_i = constant_op.constant(0, dtype=dtypes.int32)
// def cond(i, _):
// return i < array_ops.size(inputs) # pylint: disable=cell-var-from-loop
// def body(i, outputs):
// x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop
// outputs = outputs.write(i, x)
// return i + 1, outputs
// _, outputs = control_flow_ops.while_loop(cond, body,
// [initial_i, initial_outputs])
// outputs = math_ops.reduce_sum(outputs.stack())
// r = gradients_impl.gradients([outputs], [inputs])[0]
// grad_wr_inputs = ops.convert_to_tensor(r)
// o, grad = sess.run([outputs, grad_wr_inputs],
// feed_dict={inputs: [1, 3, 2]})
// self.assertEquals(o, 6)
// self.assertAllEqual(grad, [1] * 3)
}
}
}

+ 1059
- 0
test/TensorFlowNET.UnitTest/control_flow_ops_test/control_flow_ops_test.py
File diff suppressed because it is too large
View File


Loading…
Cancel
Save