|
|
|
@@ -0,0 +1,168 @@ |
|
|
|
using System;
|
|
|
|
using System.Collections.Generic;
|
|
|
|
using System.Text;
|
|
|
|
using Microsoft.VisualStudio.TestTools.UnitTesting;
|
|
|
|
using Tensorflow;
|
|
|
|
|
|
|
|
namespace TensorFlowNET.UnitTest
|
|
|
|
{
|
|
|
|
/// <summary>
|
|
|
|
/// excerpt of tensorflow/python/framework/ops_test.py
|
|
|
|
/// # These cases test the private Graph._create_op_from_tf_operation
|
|
|
|
/// # method. Arguably we should only test the public APIs that depend on this
|
|
|
|
/// # method. However, this logic is complex and tricky, and it can be difficult to
|
|
|
|
/// # ascertain if we have adequate coverage (e.g. a graph may run successfully if
|
|
|
|
/// # the control flow context isn't set properly, but a more complicated use case
|
|
|
|
/// # that might not be obvious to test will fail). Thus we instead explicitly test
|
|
|
|
/// # the low-level behavior.
|
|
|
|
/// </summary>
|
|
|
|
[TestClass]
|
|
|
|
public class CreateOpFromTfOperationTest : PythonTest
|
|
|
|
{
|
|
|
|
|
|
|
|
[TestMethod]
|
|
|
|
public void TestShape()
|
|
|
|
{
|
|
|
|
var graph = tf.Graph().as_default();
|
|
|
|
with<Graph>(graph, g =>
|
|
|
|
{
|
|
|
|
var x = constant_op.constant(new [,] { {1, 2, 3}, {4, 5, 6}});
|
|
|
|
var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]);
|
|
|
|
var op = g._create_op_from_tf_operation(c_op);
|
|
|
|
|
|
|
|
Assert.AreEqual("myop", op.name);
|
|
|
|
Assert.AreEqual("Identity", op.type);
|
|
|
|
Assert.AreEqual(1, len(op.outputs));
|
|
|
|
AssertItemsEqual(new []{2, 3}, op.outputs[0].shape);
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
|
|
|
/*def testUniqueName(self):
|
|
|
|
g = ops.Graph()
|
|
|
|
with g.as_default():
|
|
|
|
c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
|
|
|
|
c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
|
|
|
|
op = g._create_op_from_tf_operation(c_op)
|
|
|
|
op2 = g._create_op_from_tf_operation(c_op2)
|
|
|
|
|
|
|
|
# Create ops with same names as op1 and op2. We expect the new names to be
|
|
|
|
# uniquified.
|
|
|
|
op3 = test_ops.int_output(name="myop").op
|
|
|
|
op4 = test_ops.int_output(name="myop_1").op
|
|
|
|
|
|
|
|
self.assertEqual(op.name, "myop")
|
|
|
|
self.assertEqual(op2.name, "myop_1")
|
|
|
|
self.assertEqual(op3.name, "myop_2")
|
|
|
|
self.assertEqual(op4.name, "myop_1_1")
|
|
|
|
|
|
|
|
@test_util.run_v1_only("b/120545219")
|
|
|
|
def testCond(self):
|
|
|
|
g = ops.Graph()
|
|
|
|
with g.as_default():
|
|
|
|
x = test_ops.int_output()
|
|
|
|
|
|
|
|
def true_fn():
|
|
|
|
ops._create_c_op(ops.get_default_graph(),
|
|
|
|
ops._NodeDef("IntInput", "cond/myop"), [x], [])
|
|
|
|
new_ops = g._add_new_tf_operations()
|
|
|
|
self.assertEqual(len(new_ops), 1)
|
|
|
|
return x
|
|
|
|
|
|
|
|
control_flow_ops.cond(x < 10, true_fn, lambda: x)
|
|
|
|
|
|
|
|
op = g.get_operation_by_name("cond/myop")
|
|
|
|
self.assertIsNotNone(op)
|
|
|
|
self.assertEqual(op.name, "cond/myop")
|
|
|
|
self.assertEqual(op.type, "IntInput")
|
|
|
|
self.assertEqual(op.outputs, [])
|
|
|
|
op_input = op.inputs[0].op
|
|
|
|
self.assertEqual(op_input.type, "Switch")
|
|
|
|
self.assertEqual(op_input.inputs[0], x)
|
|
|
|
self.assertEqual(op.graph, g)
|
|
|
|
# pylint: disable=protected-access
|
|
|
|
self.assertIsNotNone(op._get_control_flow_context())
|
|
|
|
self.assertEqual(op._get_control_flow_context().name,
|
|
|
|
"cond/cond_text")
|
|
|
|
# pylint: enable=protected-access
|
|
|
|
|
|
|
|
@test_util.run_v1_only("b/120545219")
|
|
|
|
def testWhileLoop(self):
|
|
|
|
g = ops.Graph()
|
|
|
|
with g.as_default():
|
|
|
|
x = test_ops.int_output()
|
|
|
|
|
|
|
|
def body(i):
|
|
|
|
ops._create_c_op(ops.get_default_graph(),
|
|
|
|
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
|
|
|
|
new_ops = g._add_new_tf_operations()
|
|
|
|
self.assertEqual(len(new_ops), 1)
|
|
|
|
return i
|
|
|
|
|
|
|
|
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
|
|
|
|
|
|
|
|
op = g.get_operation_by_name("myloop/myop")
|
|
|
|
self.assertIsNotNone(op)
|
|
|
|
self.assertEqual(op.name, "myloop/myop")
|
|
|
|
self.assertEqual(op.type, "IntInput")
|
|
|
|
self.assertEqual(op.outputs, [])
|
|
|
|
op_input = op.inputs[0].op
|
|
|
|
self.assertEqual(op_input.type, "Enter")
|
|
|
|
self.assertEqual(list(op_input.inputs), [x])
|
|
|
|
self.assertEqual(op.graph, g)
|
|
|
|
# pylint: disable=protected-access
|
|
|
|
self.assertIsNotNone(op._get_control_flow_context())
|
|
|
|
self.assertEqual(op._get_control_flow_context().name,
|
|
|
|
"myloop/while_context")
|
|
|
|
# pylint: enable=protected-access
|
|
|
|
|
|
|
|
@test_util.run_v1_only("b/120545219")
|
|
|
|
def testWhileLoopWithInternalControlDep(self):
|
|
|
|
g = ops.Graph()
|
|
|
|
with g.as_default():
|
|
|
|
x = test_ops.int_output()
|
|
|
|
|
|
|
|
def body(i):
|
|
|
|
c = constant_op.constant(1.0, name="c")
|
|
|
|
ops._create_c_op(ops.get_default_graph(),
|
|
|
|
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
|
|
|
|
with ops.control_dependencies([c]):
|
|
|
|
new_ops = g._add_new_tf_operations()
|
|
|
|
self.assertEqual(len(new_ops), 1)
|
|
|
|
return i
|
|
|
|
|
|
|
|
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
|
|
|
|
|
|
|
|
op = g.get_operation_by_name("myloop/myop")
|
|
|
|
self.assertIsNotNone(op)
|
|
|
|
c = g.get_operation_by_name("myloop/c")
|
|
|
|
self.assertIsNotNone(c)
|
|
|
|
# Internal control dep is preserved
|
|
|
|
self.assertEqual(op.control_inputs, [c])
|
|
|
|
|
|
|
|
@test_util.run_v1_only("b/120545219")
|
|
|
|
def testWhileLoopWithExternalControlDep(self):
|
|
|
|
g = ops.Graph()
|
|
|
|
with g.as_default():
|
|
|
|
x = test_ops.int_output()
|
|
|
|
c = constant_op.constant(1.0)
|
|
|
|
|
|
|
|
def body(i):
|
|
|
|
ops._create_c_op(ops.get_default_graph(),
|
|
|
|
ops._NodeDef("IntInput", "myloop/myop"), [x], [])
|
|
|
|
with ops.control_dependencies([c]):
|
|
|
|
new_ops = g._add_new_tf_operations()
|
|
|
|
self.assertEqual(len(new_ops), 1)
|
|
|
|
return i
|
|
|
|
|
|
|
|
control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
|
|
|
|
|
|
|
|
op = g.get_operation_by_name("myloop/myop")
|
|
|
|
self.assertIsNotNone(op)
|
|
|
|
# External control dep is removed and replaced with internal control dep
|
|
|
|
self.assertNotEqual(op.control_inputs[0], c.op)
|
|
|
|
self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
|
|
|
|
|
|
|
|
|
|
|
|
*/
|
|
|
|
}
|
|
|
|
}
|