You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CondTestCases.cs 2.8 kB

(Line numbers 1–83)
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow;
  3. namespace TensorFlowNET.UnitTest.control_flow_ops_test
  4. {
  5. /// <summary>
  6. /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py
  7. /// </summary>
  8. [TestClass]
  9. public class CondTestCases : PythonTest
  10. {
  11. [TestMethod]
  12. public void testCondTrue()
  13. {
  14. with(tf.Graph().as_default(), g =>
  15. {
  16. var x = tf.constant(2);
  17. var y = tf.constant(5);
  18. var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
  19. () => tf.add(y, tf.constant(23)));
  20. //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false);
  21. self.assertEquals(eval_scalar(z), 34);
  22. });
  23. }
  24. [Ignore("This Test Fails due to missing edges in the graph!")]
  25. [TestMethod]
  26. public void testCondFalse()
  27. {
  28. with(tf.Graph().as_default(), g =>
  29. {
  30. var x = tf.constant(2);
  31. var y = tf.constant(1);
  32. var z = control_flow_ops.cond(tf.less(x, y), () => tf.multiply(x, tf.constant(17)),
  33. () => tf.add(y, tf.constant(23)));
  34. self.assertEquals(eval_scalar(z), 24);
  35. });
  36. }
  37. [Ignore("Todo")]
  38. [TestMethod]
  39. public void testCondMissingArg1()
  40. {
  41. // def testCondMissingArg1(self):
  42. // x = constant_op.constant(1)
  43. // with self.assertRaises(TypeError):
  44. // control_flow_ops.cond(True, false_fn=lambda: x)
  45. }
  46. [Ignore("Todo")]
  47. [TestMethod]
  48. public void testCondMissingArg2()
  49. {
  50. // def testCondMissingArg2(self):
  51. // x = constant_op.constant(1)
  52. // with self.assertRaises(TypeError):
  53. // control_flow_ops.cond(True, lambda: x)
  54. }
  55. [Ignore("Todo")]
  56. [TestMethod]
  57. public void testCondDuplicateArg1()
  58. {
  59. // def testCondDuplicateArg1(self):
  60. // x = constant_op.constant(1)
  61. // with self.assertRaises(TypeError):
  62. // control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x)
  63. }
  64. [Ignore("Todo")]
  65. [TestMethod]
  66. public void testCondDuplicateArg2()
  67. {
  68. // def testCondDuplicateArg2(self):
  69. // x = constant_op.constant(1)
  70. // with self.assertRaises(TypeError):
  71. // control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x)
  72. }
  73. }
  74. }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。