diff --git a/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs b/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs index 389f0206..62cdbba0 100644 --- a/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs +++ b/test/TensorFlowNET.UnitTest/ControlDependenciesTest.cs @@ -70,7 +70,7 @@ namespace TensorFlowNET.UnitTest { a = constant_op.constant(1.0); var b1 = future(); - with(g.control_dependencies(new [] { a, b}), ctrl => + with(g.control_dependencies(new[] { a, b }), ctrl => { c = constant_op.constant(3.0); }); @@ -170,155 +170,111 @@ namespace TensorFlowNET.UnitTest var a_3 = constant_op.constant(4.0); var a_4 = constant_op.constant(5.0); Operation b_3_4 = null, b_3 = null, b_none = null, b_1 = null, b_1_2 = null, b_none2 = null; - with(g.control_dependencies(new[] { a_1 }), ctrl1 => - { - with(g.control_dependencies(new[] { a_2 }), ctrl2 => - { - with(g.control_dependencies(null), ctrl3 => - { - with(g.control_dependencies(new[] { a_3 }), ctrl4 => - { - with(g.control_dependencies(new[] { a_4 }), ctrl5 => - { - // deps [a_3, a_4] - b_3_4 = constant_op.constant(7.0); - }); - // deps = [a_3] - b_3 = constant_op.constant(8.0); - }); - // deps back to None - b_none = constant_op.constant(9.0); - }); - // deps back to [a_1, a_2] - b_1_2 = constant_op.constant(10.0); - }); - // deps back to [a_1] - b_1 = constant_op.constant(11.0); - with(g.control_dependencies(null), ctrl6 => - { - // deps are None again - b_none2 = constant_op.constant(12.0); - }); - }); - AssertItemsEqual(new[] {a_3.op, a_4.op}, b_3_4.op.control_inputs); - AssertItemsEqual(new[] {a_3.op}, b_3.op.control_inputs); + with(g.control_dependencies(new[] { a_1 }), ctrl1 => + { + with(g.control_dependencies(new[] { a_2 }), ctrl2 => + { + with(g.control_dependencies(null), ctrl3 => + { + with(g.control_dependencies(new[] { a_3 }), ctrl4 => + { + with(g.control_dependencies(new[] { a_4 }), ctrl5 => + { + // deps [a_3, a_4] + b_3_4 = constant_op.constant(7.0); + }); + // deps = [a_3] + b_3 = constant_op.constant(8.0); + }); + // deps back to None + b_none = constant_op.constant(9.0); + }); + // deps back to [a_1, a_2] + b_1_2 = constant_op.constant(10.0); + }); + // deps back to [a_1] + b_1 = constant_op.constant(11.0); + with(g.control_dependencies(null), ctrl6 => + { + // deps are None again + b_none2 = constant_op.constant(12.0); + }); + }); + AssertItemsEqual(new[] { a_3.op, a_4.op }, b_3_4.op.control_inputs); + AssertItemsEqual(new[] { a_3.op }, b_3.op.control_inputs); AssertItemsEqual(new object[0], b_none.op.control_inputs); - AssertItemsEqual(new[] {a_1.op, a_2.op}, b_1_2.op.control_inputs); - AssertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs); + AssertItemsEqual(new[] { a_1.op, a_2.op }, b_1_2.op.control_inputs); + AssertItemsEqual(new[] { a_1.op }, b_1.op.control_inputs); AssertItemsEqual(new object[0], b_none2.op.control_inputs); - /* - def testClear(self): - g = ops.Graph() - a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - with g.control_dependencies([a_1]): - with g.control_dependencies([a_2]): - with g.control_dependencies(None): - with g.control_dependencies([a_3]): - with g.control_dependencies([a_4]): - # deps [a_3, a_4] - b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - # deps = [a_3] - b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - # deps back to None - b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - # deps back to [a_1, a_2] - b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - # deps back to [a_1] - b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - with g.control_dependencies(None): - # deps are None again - b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) - self.assertItemsEqual([a_3.op], b_3.op.control_inputs) - self.assertItemsEqual([], b_none.op.control_inputs) - self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) - self.assertItemsEqual([a_1.op], b_1.op.control_inputs) - self.assertItemsEqual([], b_none2.op.control_inputs) - */ } - [Ignore("will fail due to unsupported op 'FloatOutput'")] [TestMethod] public void TestComplex() { - /* - def testComplex(self): - g = ops.Graph() - - # Usage pattern: - # * Nodes a_i are constants defined at the outermost scope, and are used - # as control inputs for the ith nested scope. - # * Nodes b_i are defined as Mul(a_3, a_4) at each scope. - # * Nodes c_i are defined as Mul(a_1, b_1) at each scope. - # * Nodes d_i are defined as Mul(b_i, c_i) at each scope. - # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. - - a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - - with g.control_dependencies([a_1]): - b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], - [dtypes.float32]) - c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], - [dtypes.float32]) - d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1], - [dtypes.float32]) - e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) - with g.control_dependencies([a_2]): - b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], - [dtypes.float32]) - c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], - [dtypes.float32]) - d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2], - [dtypes.float32]) - e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1], - [dtypes.float32]) - with g.control_dependencies([a_3]): - b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], - [dtypes.float32]) - c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], - [dtypes.float32]) - d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3], - [dtypes.float32]) - e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2], - [dtypes.float32]) - with g.control_dependencies([a_4]): - b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4], - [dtypes.float32]) - c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1], - [dtypes.float32]) - d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4], - [dtypes.float32]) - e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3], - [dtypes.float32]) + var g = tf.Graph().as_default(); + // Usage pattern: + // * Nodes a_i are constants defined at the outermost scope, and are used + // as control inputs for the ith nested scope. + // * Nodes b_i are defined as Mul(a_3, a_4) at each scope. + // * Nodes c_i are defined as Mul(a_1, b_1) at each scope. + // * Nodes d_i are defined as Mul(b_i, c_i) at each scope. + // * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1. + var a_1 = constant_op.constant(1.0); + var a_2 = constant_op.constant(2.0); + var a_3 = constant_op.constant(3.0); + var a_4 = constant_op.constant(4.0); + Operation b_1 = null, b_2 = null, b_3 = null, b_4 = null; + Operation c_1 = null, c_2 = null, c_3 = null, c_4 = null; + Operation d_1 = null, d_2 = null, d_3 = null, d_4 = null; + Operation e_1 = null, e_2 = null, e_3 = null, e_4 = null; + with(g.control_dependencies(new[] { a_1 }), ctrl1 => + { + b_1 = tf.multiply(a_3, a_4); + c_1 = tf.multiply(a_1, b_1.output); + d_1 = tf.multiply(b_1.output, c_1.output); + e_1 = constant_op.constant(5.0); + with(g.control_dependencies(new[] { a_2 }), ctrl2 => + { + b_2 = tf.multiply(a_3, a_4); + c_2 = tf.multiply(a_1, b_1.output); + d_2 = tf.multiply(b_2.output, c_2.output); + e_2 = tf.multiply(e_1.output, e_1.output); + with(g.control_dependencies(new[] { a_3 }), ctrl3 => + { + b_3 = tf.multiply(a_3, a_4); + c_3 = tf.multiply(a_1, b_1.output); + d_3 = tf.multiply(b_3.output, c_3.output); + e_3 = tf.multiply(e_2.output, e_2.output); + with(g.control_dependencies(new[] { a_4 }), ctrl4 => + { + b_4 = tf.multiply(a_3, a_4); + c_4 = tf.multiply(a_1, b_1.output); + d_4 = tf.multiply(b_4.output, c_4.output); + e_4 = tf.multiply(e_3.output, e_3.output); + }); + }); + }); + }); - self.assertItemsEqual([a_1.op], b_1.op.control_inputs) - self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs) - self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs) - self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs) + AssertItemsEqual(new[] {a_1.op}, b_1.op.control_inputs); + AssertItemsEqual(new[] {a_1.op, a_2.op}, b_2.op.control_inputs); + AssertItemsEqual(new[] { a_1.op, a_2.op}, b_3.op.control_inputs); + AssertItemsEqual(new[] {a_1.op, a_2.op}, b_4.op.control_inputs); - self.assertItemsEqual([], c_1.op.control_inputs) - self.assertItemsEqual([a_2.op], c_2.op.control_inputs) - self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs) - self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs) + AssertItemsEqual(new object[0], c_1.op.control_inputs); + AssertItemsEqual(new[] {a_2.op}, c_2.op.control_inputs); + AssertItemsEqual(new[] {a_2.op, a_3.op}, c_3.op.control_inputs); + AssertItemsEqual(new[] {a_2.op, a_3.op, a_4.op}, c_4.op.control_inputs); - self.assertItemsEqual([], d_1.op.control_inputs) - self.assertItemsEqual([], d_2.op.control_inputs) - self.assertItemsEqual([], d_3.op.control_inputs) - self.assertItemsEqual([], d_4.op.control_inputs) + AssertItemsEqual(new object[0], d_1.op.control_inputs); + AssertItemsEqual(new object[0], d_2.op.control_inputs); + AssertItemsEqual(new object[0], d_3.op.control_inputs); + AssertItemsEqual(new object[0], d_4.op.control_inputs); - self.assertItemsEqual([a_1.op], e_1.op.control_inputs) - self.assertItemsEqual([a_2.op], e_2.op.control_inputs) - self.assertItemsEqual([a_3.op], e_3.op.control_inputs) - self.assertItemsEqual([a_4.op], e_4.op.control_inputs) - */ + AssertItemsEqual(new[] {a_1.op}, e_1.op.control_inputs); + AssertItemsEqual(new[] {a_2.op}, e_2.op.control_inputs); + AssertItemsEqual(new[] {a_3.op}, e_3.op.control_inputs); + AssertItemsEqual(new[] {a_4.op}, e_4.op.control_inputs); } [Ignore("will fail due to unsupported op 'FloatOutput'")]