You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

CondTestCases.cs 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  using Microsoft.VisualStudio.TestTools.UnitTesting;
  using System;
  using Tensorflow;

  namespace TensorFlowNET.UnitTest.control_flow_ops_test
  {
      /// <summary>
      /// Port of an excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py,
      /// exercising <c>control_flow_ops.cond</c>.
      /// </summary>
      [TestClass]
      public class CondTestCases : PythonTest
      {
          /// <summary>
          /// Builds cond(less(x, y), x * 17, y + 23) in a freshly created default graph,
          /// evaluates it in a session bound to that graph and asserts the result.
          /// </summary>
          /// <param name="x_value">Value of the int constant <c>x</c>.</param>
          /// <param name="y_value">Value of the int constant <c>y</c>.</param>
          /// <param name="expected">Value the cond output is expected to evaluate to.</param>
          private void assertLessCond(int x_value, int y_value, int expected)
          {
              // A dedicated graph keeps each test independent of whatever nodes
              // other tests may have left in the shared default graph.
              var graph = tf.Graph().as_default();
              with(tf.Session(graph), sess =>
              {
                  var x = tf.constant(x_value);
                  var y = tf.constant(y_value);
                  var pred = tf.less(x, y);
                  Func<ITensorOrOperation> if_true = () => tf.multiply(x, 17);
                  Func<ITensorOrOperation> if_false = () => tf.add(y, 23);
                  var z = control_flow_ops.cond(pred, if_true, if_false);
                  int result = z.eval(sess);
                  assertEquals(result, expected);
              });
          }

          /// <summary>
          /// 2 is less than 5, so the true branch runs: 2 * 17 = 34.
          /// </summary>
          [TestMethod]
          public void testCondTrue()
          {
              assertLessCond(2, 5, 34);
          }

          /// <summary>
          /// 2 is not less than 1, so the false branch runs: 1 + 23 = 24.
          /// </summary>
          [TestMethod]
          public void testCondFalse()
          {
              // Python reference:
              //   import tensorflow as tf
              //   from tensorflow.python.framework import ops
              //   def if_true():
              //       return tf.math.multiply(x, 17)
              //   def if_false():
              //       return tf.math.add(y, 23)
              //   with tf.Session() as sess:
              //       x = tf.constant(2)
              //       y = tf.constant(1)
              //       pred = tf.math.less(x, y)
              //       z = tf.cond(pred, if_true, if_false)
              //       result = z.eval()
              //       print(result == 24)
              assertLessCond(2, 1, 24);
          }

          // The tests below are not ported yet; each body holds the Python original.

          [Ignore("Todo")]
          [TestMethod]
          public void testCondTrueLegacy()
          {
              // def testCondTrueLegacy(self):
              //     x = constant_op.constant(2)
              //     y = constant_op.constant(5)
              //     z = control_flow_ops.cond(
              //         math_ops.less(x, y),
              //         fn1=lambda: math_ops.multiply(x, 17),
              //         fn2=lambda: math_ops.add(y, 23))
              //     self.assertEquals(self.evaluate(z), 34)
          }

          [Ignore("Todo")]
          [TestMethod]
          public void testCondFalseLegacy()
          {
              // def testCondFalseLegacy(self):
              //     x = constant_op.constant(2)
              //     y = constant_op.constant(1)
              //     z = control_flow_ops.cond(
              //         math_ops.less(x, y),
              //         fn1=lambda: math_ops.multiply(x, 17),
              //         fn2=lambda: math_ops.add(y, 23))
              //     self.assertEquals(self.evaluate(z), 24)
          }

          // NOTE(review): the MissingArg/DuplicateArg cases check Python keyword-argument
          // TypeErrors; a C# signature may already rule these out at compile time —
          // confirm whether they need a runtime port at all.

          [Ignore("Todo")]
          [TestMethod]
          public void testCondMissingArg1()
          {
              // def testCondMissingArg1(self):
              //     x = constant_op.constant(1)
              //     with self.assertRaises(TypeError):
              //         control_flow_ops.cond(True, false_fn=lambda: x)
          }

          [Ignore("Todo")]
          [TestMethod]
          public void testCondMissingArg2()
          {
              // def testCondMissingArg2(self):
              //     x = constant_op.constant(1)
              //     with self.assertRaises(TypeError):
              //         control_flow_ops.cond(True, lambda: x)
          }

          [Ignore("Todo")]
          [TestMethod]
          public void testCondDuplicateArg1()
          {
              // def testCondDuplicateArg1(self):
              //     x = constant_op.constant(1)
              //     with self.assertRaises(TypeError):
              //         control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
          }

          [Ignore("Todo")]
          [TestMethod]
          public void testCondDuplicateArg2()
          {
              // def testCondDuplicateArg2(self):
              //     x = constant_op.constant(1)
              //     with self.assertRaises(TypeError):
              //         control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。