You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ControlDependenciesTest.cs 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Microsoft.VisualStudio.TestTools.UnitTesting;
  6. using Tensorflow;
  7. using Tensorflow.Eager;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Port of the control-dependency test cases from
      /// tensorflow/python/framework/ops_test.py (ControlDependenciesTest).
      /// Cases that cannot be expressed in C# yet are kept as [Ignore]d placeholders
      /// whose bodies carry the original Python source for reference.
      /// </summary>
      [TestClass]
      public class ControlDependenciesTest : Python
      {
          /// <summary>
          /// Ops created inside a control_dependencies scope receive the scope's ops as
          /// control inputs; an op whose data input (c) was itself created inside the
          /// scope gets no extra control inputs.
          /// </summary>
          [TestMethod]
          public void TestBasic()
          {
              var graph = tf.Graph().as_default();
              Tensor a = null, b = null, c = null, d = null, e = null;
              with<Graph>(graph, g =>
              {
                  a = constant_op.constant(1.0);
                  b = constant_op.constant(1.0);
                  with(g.control_dependencies(new ITensorOrOperation[] { a }), x =>
                  {
                      c = constant_op.constant(1.0);
                      d = array_ops.identity(b);
                      e = array_ops.identity(c);
                  });
              });

              // CollectionAssert compares element-wise in order (like SequenceEqual)
              // but reports the differing elements on failure.
              CollectionAssert.AreEqual(new[] { a.op }, c.op.control_inputs);
              CollectionAssert.AreEqual(new[] { a.op }, d.op.control_inputs);
              // e should be dominated by c.
              Assert.AreEqual(0, e.op.control_inputs.Length);
          }

          /// <summary>
          /// Control dependencies on a tensor produced by a counted factory function.
          /// Only the graph-mode branch is ported; the eager branch needs
          /// control_dependencies to accept callables (see TODO below).
          /// </summary>
          [Ignore("Part of this test is not compiling")]
          [TestMethod]
          public void TestEager()
          {
              Tensor a = null, b = null, c = null;
              var calls = 0;
              // Counts invocations so the test can check the factory ran exactly once.
              Func<Tensor> future = () =>
              {
                  calls += 1;
                  return constant_op.constant(2.0);
              };
              using (var opts = new ContextOptions())
              using (var status = new Status())
              using (var context = new Context(opts, status))
              {
                  if (context.executing_eagerly())
                  {
                      // TODO: make this compile (see original Python code below).
                      // {henon}: Python passes the callable `future` itself; this does not
                      // compile because control_dependencies would need to accept callables too.
                      //a = constant_op.constant(1.0);
                      //b = future;
                      //with(ops.control_dependencies(new Operation[] {a, b}), ctrl =>
                      //{
                      //    return c = constant_op.constant(3.0);
                      //});
                      //Assert.AreEqual(1, calls);
                  }
                  else
                  {
                      var graph = tf.Graph();
                      with<Graph>(graph.as_default(), g =>
                      {
                          a = constant_op.constant(1.0);
                          b = future();
                          with(g.control_dependencies(new ITensorOrOperation[] { a, b }), ctrl =>
                          {
                              c = constant_op.constant(3.0);
                          });
                          CollectionAssert.AreEqual(new[] { a.op, b.op }, c.op.control_inputs);
                          Assert.AreEqual(1, calls);
                      });
                  }
              }
              /*
                def testEager(self):
                  def future():
                    future.calls += 1
                    return constant_op.constant(2.0)
                  future.calls = 0

                  if context.executing_eagerly():
                    a = constant_op.constant(1.0)
                    b = future
                    with ops.control_dependencies([a, b]):
                      c = constant_op.constant(3.0)
                    self.assertEqual(future.calls, 1)
                  else:
                    g = ops.Graph()
                    with g.as_default():
                      a = constant_op.constant(1.0)
                      b = future()
                      with g.control_dependencies([a, b]):
                        c = constant_op.constant(3.0)
                      self.assertEqual(c.op.control_inputs, [a.op, b.op])
                      self.assertEqual(future.calls, 1)
              */
          }

          /// <summary>
          /// Placeholder: control_dependencies accepting an object convertible to a graph element.
          /// </summary>
          [Ignore("How to translate _apply_op into c#?")]
          [TestMethod]
          public void TestBasicWithConversion()
          {
              /*
                def testBasicWithConversion(self):
                  g = ops.Graph()
                  a = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                  class ConvertibleObj(object):

                    def _as_graph_element(self):
                      return a

                  with g.control_dependencies([ConvertibleObj()]):
                    c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                  self.assertEqual(c.op.control_inputs, [a.op])
              */
          }

          /// <summary>
          /// Placeholder: nested scopes accumulate the same control inputs as one flat scope.
          /// </summary>
          [Ignore("How to translate _apply_op into c#?")]
          [TestMethod]
          public void TestNested()
          {
              /*
                def testNested(self):
                  g = ops.Graph()
                  a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                  with g.control_dependencies([a_1, a_2, a_3, a_4]):
                    b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                  with g.control_dependencies([a_1]):
                    with g.control_dependencies([a_2]):
                      with g.control_dependencies([a_3]):
                        with g.control_dependencies([a_4]):
                          b_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                  self.assertItemsEqual([a_1.op, a_2.op, a_3.op, a_4.op],
                                        b_1.op.control_inputs)
                  self.assertItemsEqual(b_1.op.control_inputs, b_2.op.control_inputs)
              */
          }

          /// <summary>
          /// Placeholder: control_dependencies(None) clears inherited dependencies for its scope.
          /// </summary>
          [Ignore("How to translate _apply_op into c#?")]
          [TestMethod]
          public void TestClear()
          {
              /*
                def testClear(self):
                  g = ops.Graph()
                  a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                  with g.control_dependencies([a_1]):
                    with g.control_dependencies([a_2]):
                      with g.control_dependencies(None):
                        with g.control_dependencies([a_3]):
                          with g.control_dependencies([a_4]):
                            # deps [a_3, a_4]
                            b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                          # deps = [a_3]
                          b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                        # deps back to None
                        b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                      # deps back to [a_1, a_2]
                      b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                    # deps back to [a_1]
                    b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                    with g.control_dependencies(None):
                      # deps are None again
                      b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                  self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
                  self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
                  self.assertItemsEqual([], b_none.op.control_inputs)
                  self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
                  self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
                  self.assertItemsEqual([], b_none2.op.control_inputs)
              */
          }

          /// <summary>
          /// Placeholder: interaction of nested control dependencies with data dependencies
          /// (domination and exclusion of inputs that are already data inputs).
          /// </summary>
          [Ignore("How to translate _apply_op into c#?")]
          [TestMethod]
          public void TestComplex()
          {
              /*
                def testComplex(self):
                  g = ops.Graph()

                  # Usage pattern:
                  # * Nodes a_i are constants defined at the outermost scope, and are used
                  #   as control inputs for the ith nested scope.
                  # * Nodes b_i are defined as Mul(a_3, a_4) at each scope.
                  # * Nodes c_i are defined as Mul(a_1, b_1) at each scope.
                  # * Nodes d_i are defined as Mul(b_i, c_i) at each scope.
                  # * Nodes e_i are defined as Mul(e_i-1, e_i-1) at each scope i > 1.

                  a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                  with g.control_dependencies([a_1]):
                    b_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                                    [dtypes.float32])
                    c_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                                    [dtypes.float32])
                    d_1 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_1, c_1],
                                    [dtypes.float32])
                    e_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                    with g.control_dependencies([a_2]):
                      b_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                                      [dtypes.float32])
                      c_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                                      [dtypes.float32])
                      d_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_2, c_2],
                                      [dtypes.float32])
                      e_2 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_1, e_1],
                                      [dtypes.float32])
                      with g.control_dependencies([a_3]):
                        b_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                                        [dtypes.float32])
                        c_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                                        [dtypes.float32])
                        d_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_3, c_3],
                                        [dtypes.float32])
                        e_3 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_2, e_2],
                                        [dtypes.float32])
                        with g.control_dependencies([a_4]):
                          b_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_3, a_4],
                                          [dtypes.float32])
                          c_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [a_1, b_1],
                                          [dtypes.float32])
                          d_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [b_4, c_4],
                                          [dtypes.float32])
                          e_4 = _apply_op(g, "TwoFloatInputsFloatOutput", [e_3, e_3],
                                          [dtypes.float32])

                  self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
                  self.assertItemsEqual([a_1.op, a_2.op], b_2.op.control_inputs)
                  self.assertItemsEqual([a_1.op, a_2.op], b_3.op.control_inputs)
                  self.assertItemsEqual([a_1.op, a_2.op], b_4.op.control_inputs)

                  self.assertItemsEqual([], c_1.op.control_inputs)
                  self.assertItemsEqual([a_2.op], c_2.op.control_inputs)
                  self.assertItemsEqual([a_2.op, a_3.op], c_3.op.control_inputs)
                  self.assertItemsEqual([a_2.op, a_3.op, a_4.op], c_4.op.control_inputs)

                  self.assertItemsEqual([], d_1.op.control_inputs)
                  self.assertItemsEqual([], d_2.op.control_inputs)
                  self.assertItemsEqual([], d_3.op.control_inputs)
                  self.assertItemsEqual([], d_4.op.control_inputs)

                  self.assertItemsEqual([a_1.op], e_1.op.control_inputs)
                  self.assertItemsEqual([a_2.op], e_2.op.control_inputs)
                  self.assertItemsEqual([a_3.op], e_3.op.control_inputs)
                  self.assertItemsEqual([a_4.op], e_4.op.control_inputs)
              */
          }

          /// <summary>
          /// Placeholder: two outputs of the same op used as control inputs collapse to that op.
          /// </summary>
          [Ignore("How to translate _apply_op into c#?")]
          [TestMethod]
          public void TestRepeatedDependency()
          {
              /*
                def testRepeatedDependency(self):
                  g = ops.Graph()
                  a = g.create_op("TwoFloatOutputs", [], [dtypes.float32, dtypes.float32])
                  a_0, a_1 = a.outputs
                  with g.control_dependencies([a_0]):
                    b = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                    with g.control_dependencies([a_1]):
                      c = _apply_op(g, "FloatOutput", [], [dtypes.float32])

                  self.assertEqual(b.op.control_inputs, [a])
                  self.assertEqual(c.op.control_inputs, [a])
              */
          }

          /// <summary>
          /// Placeholder: an op that already has a data dependency on a control-scope op
          /// gets no redundant control input on it. (Split out of TestRepeatedDependency's
          /// reference comment, where it had been merged; it is a separate case in ops_test.py.)
          /// </summary>
          [Ignore("How to translate _apply_op into c#?")]
          [TestMethod]
          public void TestNoControlDependencyWithDataDependency()
          {
              /*
                def testNoControlDependencyWithDataDependency(self):
                  g = ops.Graph()
                  a = _apply_op(g, "FloatOutput", [], [dtypes.float32])
                  with g.control_dependencies([a]):
                    b = _apply_op(g, "Identity", [a], [dtypes.float32])

                  self.assertEqual(b.op.control_inputs, [])
              */
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。