You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CreateOpFromTfOperationTest.cs 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Microsoft.VisualStudio.TestTools.UnitTesting;
  5. using Tensorflow;
  6. namespace TensorFlowNET.UnitTest
  7. {
  8. /// <summary>
  9. /// excerpt of tensorflow/python/framework/ops_test.py
  10. /// # These cases test the private Graph._create_op_from_tf_operation
  11. /// # method. Arguably we should only test the public APIs that depend on this
  12. /// # method. However, this logic is complex and tricky, and it can be difficult to
  13. /// # ascertain if we have adequate coverage (e.g. a graph may run successfully if
  14. /// # the control flow context isn't set properly, but a more complicated use case
  15. /// # that might not be obvious to test will fail). Thus we instead explicitly test
  16. /// # the low-level behavior.
  17. /// </summary>
  18. [TestClass]
  19. public class CreateOpFromTfOperationTest : PythonTest
  20. {
  /// <summary>
  /// Builds an "Identity" node named "myop" over a 2x3 integer constant through
  /// <c>ops._create_c_op</c>, wraps the resulting native op with
  /// <c>_create_op_from_tf_operation</c>, and checks the wrapped operation's
  /// name, type, output count and the shape of its single output.
  /// </summary>
  [TestMethod]
  public void TestShape()
  {
      var defaultGraph = tf.Graph().as_default();

      with<Graph>(defaultGraph, scope =>
      {
          // Input tensor: two rows of three integers.
          var matrix = constant_op.constant(new[,]
          {
              { 1, 2, 3 },
              { 4, 5, 6 }
          });

          // Create the native op first; the operation descriptor returned
          // alongside it is not used by this test.
          var nodeDef = ops._NodeDef("Identity", "myop");
          var inputs = new[] { matrix };
          var (nativeOp, _) = ops._create_c_op(scope, nodeDef, inputs, new Operation[0]);

          // Wrap the native op in a managed Operation registered on the graph.
          var wrapped = scope._create_op_from_tf_operation(nativeOp);

          Assert.AreEqual("myop", wrapped.name);
          Assert.AreEqual("Identity", wrapped.type);
          Assert.AreEqual(1, len(wrapped.outputs));

          // The output of Identity is expected to carry the 2x3 shape of its input.
          var expectedShape = new[] { 2, 3 };
          AssertItemsEqual(expectedShape, wrapped.outputs[0].shape);
      });
  }
  36. /*def testUniqueName(self):
  37. g = ops.Graph()
  38. with g.as_default():
  39. c_op = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop"), [], [])
  40. c_op2 = ops._create_c_op(g, ops._NodeDef("IntOutput", "myop_1"), [], [])
  41. op = g._create_op_from_tf_operation(c_op)
  42. op2 = g._create_op_from_tf_operation(c_op2)
  43. # Create ops with same names as op1 and op2. We expect the new names to be
  44. # uniquified.
  45. op3 = test_ops.int_output(name="myop").op
  46. op4 = test_ops.int_output(name="myop_1").op
  47. self.assertEqual(op.name, "myop")
  48. self.assertEqual(op2.name, "myop_1")
  49. self.assertEqual(op3.name, "myop_2")
  50. self.assertEqual(op4.name, "myop_1_1")
  51. @test_util.run_v1_only("b/120545219")
  52. def testCond(self):
  53. g = ops.Graph()
  54. with g.as_default():
  55. x = test_ops.int_output()
  56. def true_fn():
  57. ops._create_c_op(ops.get_default_graph(),
  58. ops._NodeDef("IntInput", "cond/myop"), [x], [])
  59. new_ops = g._add_new_tf_operations()
  60. self.assertEqual(len(new_ops), 1)
  61. return x
  62. control_flow_ops.cond(x < 10, true_fn, lambda: x)
  63. op = g.get_operation_by_name("cond/myop")
  64. self.assertIsNotNone(op)
  65. self.assertEqual(op.name, "cond/myop")
  66. self.assertEqual(op.type, "IntInput")
  67. self.assertEqual(op.outputs, [])
  68. op_input = op.inputs[0].op
  69. self.assertEqual(op_input.type, "Switch")
  70. self.assertEqual(op_input.inputs[0], x)
  71. self.assertEqual(op.graph, g)
  72. # pylint: disable=protected-access
  73. self.assertIsNotNone(op._get_control_flow_context())
  74. self.assertEqual(op._get_control_flow_context().name,
  75. "cond/cond_text")
  76. # pylint: enable=protected-access
  77. @test_util.run_v1_only("b/120545219")
  78. def testWhileLoop(self):
  79. g = ops.Graph()
  80. with g.as_default():
  81. x = test_ops.int_output()
  82. def body(i):
  83. ops._create_c_op(ops.get_default_graph(),
  84. ops._NodeDef("IntInput", "myloop/myop"), [x], [])
  85. new_ops = g._add_new_tf_operations()
  86. self.assertEqual(len(new_ops), 1)
  87. return i
  88. control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
  89. op = g.get_operation_by_name("myloop/myop")
  90. self.assertIsNotNone(op)
  91. self.assertEqual(op.name, "myloop/myop")
  92. self.assertEqual(op.type, "IntInput")
  93. self.assertEqual(op.outputs, [])
  94. op_input = op.inputs[0].op
  95. self.assertEqual(op_input.type, "Enter")
  96. self.assertEqual(list(op_input.inputs), [x])
  97. self.assertEqual(op.graph, g)
  98. # pylint: disable=protected-access
  99. self.assertIsNotNone(op._get_control_flow_context())
  100. self.assertEqual(op._get_control_flow_context().name,
  101. "myloop/while_context")
  102. # pylint: enable=protected-access
  103. @test_util.run_v1_only("b/120545219")
  104. def testWhileLoopWithInternalControlDep(self):
  105. g = ops.Graph()
  106. with g.as_default():
  107. x = test_ops.int_output()
  108. def body(i):
  109. c = constant_op.constant(1.0, name="c")
  110. ops._create_c_op(ops.get_default_graph(),
  111. ops._NodeDef("IntInput", "myloop/myop"), [x], [])
  112. with ops.control_dependencies([c]):
  113. new_ops = g._add_new_tf_operations()
  114. self.assertEqual(len(new_ops), 1)
  115. return i
  116. control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
  117. op = g.get_operation_by_name("myloop/myop")
  118. self.assertIsNotNone(op)
  119. c = g.get_operation_by_name("myloop/c")
  120. self.assertIsNotNone(c)
  121. # Internal control dep is preserved
  122. self.assertEqual(op.control_inputs, [c])
  123. @test_util.run_v1_only("b/120545219")
  124. def testWhileLoopWithExternalControlDep(self):
  125. g = ops.Graph()
  126. with g.as_default():
  127. x = test_ops.int_output()
  128. c = constant_op.constant(1.0)
  129. def body(i):
  130. ops._create_c_op(ops.get_default_graph(),
  131. ops._NodeDef("IntInput", "myloop/myop"), [x], [])
  132. with ops.control_dependencies([c]):
  133. new_ops = g._add_new_tf_operations()
  134. self.assertEqual(len(new_ops), 1)
  135. return i
  136. control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
  137. op = g.get_operation_by_name("myloop/myop")
  138. self.assertIsNotNone(op)
  139. # External control dep is removed and replaced with internal control dep
  140. self.assertNotEqual(op.control_inputs[0], c.op)
  141. self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
  142. */
  143. }
  144. }

TensorFlow 框架的 .NET 版本,提供了丰富的特性和 API,可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。