You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CreateOpFromTfOperationTest.cs 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Microsoft.VisualStudio.TestTools.UnitTesting;
  6. using Tensorflow;
  7. using Tensorflow.Operations;
  namespace TensorFlowNET.UnitTest.ops_test
  {
      /// <summary>
      /// Excerpt of tensorflow/python/framework/ops_test.py.
      ///
      /// These cases test the private Graph._create_op_from_tf_operation
      /// method. Arguably we should only test the public APIs that depend on this
      /// method. However, this logic is complex and tricky, and it can be difficult to
      /// ascertain if we have adequate coverage (e.g. a graph may run successfully if
      /// the control flow context isn't set properly, but a more complicated use case
      /// that might not be obvious to test will fail). Thus we instead explicitly test
      /// the low-level behavior.
      /// </summary>
      [TestClass]
      public class CreateOpFromTfOperationTest : PythonTest
      {
          /// <summary>
          /// Builds a raw C op from an "Identity" NodeDef, wraps it with
          /// <c>_create_op_from_tf_operation</c>, and checks the wrapper's name,
          /// type, output count and output shape (same as the 2x3 input).
          /// </summary>
          [TestMethod]
          public void TestShape()
          {
              var graph = tf.Graph().as_default();
              with<Graph>(graph, g =>
              {
                  var x = constant_op.constant(new[,] { { 1, 2, 3 }, { 4, 5, 6 } });
                  var (c_op, _) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] { x }, new Operation[0]);
                  var op = g._create_op_from_tf_operation(c_op);

                  Assert.AreEqual("myop", op.name);
                  Assert.AreEqual("Identity", op.type);
                  Assert.AreEqual(1, len(op.outputs));
                  assertItemsEqual(new[] { 2, 3 }, op.outputs[0].shape);
              });
          }

          /// <summary>
          /// Creating ops whose requested names collide with existing ops
          /// ("myop", "myop_1") must produce uniquified names ("myop_2", "myop_1_1").
          /// </summary>
          [TestMethod]
          public void TestUniqueName()
          {
              var graph = tf.Graph().as_default();
              with<Graph>(graph, g =>
              {
                  // NOTE: the Python original creates the first two ops via
                  // ops._create_c_op + g._create_op_from_tf_operation; this port
                  // creates them through constant_op.constant instead.
                  var op = constant_op.constant(0, name: "myop").op;
                  var op2 = constant_op.constant(0, name: "myop_1").op;

                  // Create ops with the same names as op and op2. We expect the new
                  // names to be uniquified.
                  var op3 = constant_op.constant(0, name: "myop").op;
                  var op4 = constant_op.constant(0, name: "myop_1").op;

                  self.assertEqual(op.name, "myop");
                  self.assertEqual(op2.name, "myop_1");
                  self.assertEqual(op3.name, "myop_2");
                  self.assertEqual(op4.name, "myop_1_1");
              });
          }

          /// <summary>
          /// An op created with <c>_create_c_op</c> inside a cond branch and then
          /// registered via <c>_add_new_tf_operations</c> must be wired to the
          /// branch's Switch op and carry the cond's control-flow context.
          /// </summary>
          [Ignore("Switch op gets not inserted correctly in the graph")]
          [TestMethod]
          public void TestCond()
          {
              var graph = tf.Graph().as_default();
              with<Graph>(graph, g =>
              {
                  var x = constant_op.constant(10);
                  var true_fn = new Func<Tensor>(() =>
                  {
                      // Only the side effect matters here: a raw C op named
                      // "cond/myop" is created, then picked up by the graph.
                      ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]);
                      var new_ops = g._add_new_tf_operations();
                      self.assertEqual(len(new_ops), 1);
                      return x;
                  });
                  control_flow_ops.cond(x < 10, true_fn, () => x);

                  var op = g.get_operation_by_name("cond/myop");
                  self.assertIsNotNone(op);
                  self.assertEqual(op.name, "cond/myop");
                  self.assertEqual(op.type, "Identity");
                  // Python checks `op.outputs == []` because it uses the output-less
                  // IntInput op; Identity has one output, so that check is not ported.

                  var op_input = op.inputs[0].op;
                  self.assertEqual(op_input.type, "Switch");
                  self.assertEqual(op_input.inputs[0], x);
                  self.assertEqual(op.graph, g);
                  self.assertIsNotNone(op._get_control_flow_context());
                  self.assertEqual(((ControlFlowContext)op._get_control_flow_context()).name, "cond/cond_text");
              });
              /*
              @test_util.run_v1_only("b/120545219")
              def testCond(self):
                  g = ops.Graph()
                  with g.as_default():
                      x = test_ops.int_output()
                      def true_fn():
                          ops._create_c_op(ops.get_default_graph(),
                                           ops._NodeDef("IntInput", "cond/myop"), [x], [])
                          new_ops = g._add_new_tf_operations()
                          self.assertEqual(len(new_ops), 1)
                          return x
                      control_flow_ops.cond(x < 10, true_fn, lambda: x)
                  op = g.get_operation_by_name("cond/myop")
                  self.assertIsNotNone(op)
                  self.assertEqual(op.name, "cond/myop")
                  self.assertEqual(op.type, "IntInput")
                  self.assertEqual(op.outputs, [])
                  op_input = op.inputs[0].op
                  self.assertEqual(op_input.type, "Switch")
                  self.assertEqual(op_input.inputs[0], x)
                  self.assertEqual(op.graph, g)
                  # pylint: disable=protected-access
                  self.assertIsNotNone(op._get_control_flow_context())
                  self.assertEqual(op._get_control_flow_context().name,
                                   "cond/cond_text")
                  # pylint: enable=protected-access
              */
          }

          /// <summary>
          /// An op created inside a while-loop body must be wired to the loop's
          /// Enter op and carry the loop's control-flow context.
          /// Blocked until control_flow_ops.while_loop is ported.
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void TestWhileLoop()
          {
              var graph = tf.Graph().as_default();
              // Declared outside the `with` scope because the assertions below run
              // after the graph context is exited (mirrors the Python test).
              // constant_op.constant yields a Tensor, which is also what
              // _create_c_op expects as input.
              Tensor x = null;
              with<Graph>(graph, g =>
              {
                  x = constant_op.constant(42);
                  var body = new Func<int, int>(i =>
                  {
                      ops._create_c_op(ops.get_default_graph(), ops._NodeDef("Identity", "myloop/myop"), new[] { x }, new Operation[0]);
                      var new_ops = g._add_new_tf_operations();
                      self.assertEqual(len(new_ops), 1);
                      return i;
                  });
                  // TODO: port control_flow_ops.while_loop
                  //control_flow_ops.while_loop(i => i < 10, body, new int[] { 0 }, name: "myloop");
              });

              var op = graph.get_operation_by_name("myloop/myop");
              self.assertIsNotNone(op);
              self.assertEqual(op.name, "myloop/myop");
              self.assertEqual(op.type, "Identity");
              // Python expects no outputs because it uses the output-less IntInput op;
              // this port uses Identity, which has exactly one output (see TestShape).
              self.assertEqual(op.outputs.Length, 1);

              var op_input = op.inputs[0].op;
              self.assertEqual(op_input.type, "Enter");
              // The op's inputs are tensors, so filter by Tensor (filtering by
              // Operation would always yield an empty array).
              self.assertItemsEqual(op_input.inputs.OfType<Tensor>().ToArray(), new[] { x });
              self.assertEqual(op.graph, graph);
              self.assertIsNotNone(op._get_control_flow_context());
              self.assertEqual(((ControlFlowContext)op._get_control_flow_context()).name, "myloop/while_context");
              /*
              @test_util.run_v1_only("b/120545219")
              def testWhileLoop(self):
                  g = ops.Graph()
                  with g.as_default():
                      x = test_ops.int_output()
                      def body(i):
                          ops._create_c_op(ops.get_default_graph(),
                                           ops._NodeDef("IntInput", "myloop/myop"), [x], [])
                          new_ops = g._add_new_tf_operations()
                          self.assertEqual(len(new_ops), 1)
                          return i
                      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
                  op = g.get_operation_by_name("myloop/myop")
                  self.assertIsNotNone(op)
                  self.assertEqual(op.name, "myloop/myop")
                  self.assertEqual(op.type, "IntInput")
                  self.assertEqual(op.outputs, [])
                  op_input = op.inputs[0].op
                  self.assertEqual(op_input.type, "Enter")
                  self.assertEqual(list(op_input.inputs), [x])
                  self.assertEqual(op.graph, g)
                  # pylint: disable=protected-access
                  self.assertIsNotNone(op._get_control_flow_context())
                  self.assertEqual(op._get_control_flow_context().name,
                                   "myloop/while_context")
                  # pylint: enable=protected-access
              */
          }

          /// <summary>
          /// Not yet ported: a control dependency created inside the loop body
          /// should be preserved on the new op (see the Python reference below).
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void TestWhileLoopWithInternalControlDep()
          {
              /*
              @test_util.run_v1_only("b/120545219")
              def testWhileLoopWithInternalControlDep(self):
                  g = ops.Graph()
                  with g.as_default():
                      x = test_ops.int_output()
                      def body(i):
                          c = constant_op.constant(1.0, name="c")
                          ops._create_c_op(ops.get_default_graph(),
                                           ops._NodeDef("IntInput", "myloop/myop"), [x], [])
                          with ops.control_dependencies([c]):
                              new_ops = g._add_new_tf_operations()
                              self.assertEqual(len(new_ops), 1)
                          return i
                      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
                  op = g.get_operation_by_name("myloop/myop")
                  self.assertIsNotNone(op)
                  c = g.get_operation_by_name("myloop/c")
                  self.assertIsNotNone(c)
                  # Internal control dep is preserved
                  self.assertEqual(op.control_inputs, [c])
              */
          }

          /// <summary>
          /// Not yet ported: a control dependency on an op outside the loop should
          /// be replaced by an internal one (see the Python reference below).
          /// </summary>
          [Ignore("Todo: Port")]
          [TestMethod]
          public void TestWhileLoopWithExternalControlDep()
          {
              /*
              @test_util.run_v1_only("b/120545219")
              def testWhileLoopWithExternalControlDep(self):
                  g = ops.Graph()
                  with g.as_default():
                      x = test_ops.int_output()
                      c = constant_op.constant(1.0)
                      def body(i):
                          ops._create_c_op(ops.get_default_graph(),
                                           ops._NodeDef("IntInput", "myloop/myop"), [x], [])
                          with ops.control_dependencies([c]):
                              new_ops = g._add_new_tf_operations()
                              self.assertEqual(len(new_ops), 1)
                          return i
                      control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="myloop")
                  op = g.get_operation_by_name("myloop/myop")
                  self.assertIsNotNone(op)
                  # External control dep is removed and replaced with internal control dep
                  self.assertNotEqual(op.control_inputs[0], c.op)
                  self.assertIsNotNone(op.control_inputs[0]._get_control_flow_context())
              */
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。