You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CondTestCases.cs 2.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  using Microsoft.VisualStudio.TestTools.UnitTesting;
  using System;
  using Tensorflow;

  namespace TensorFlowNET.UnitTest.control_flow_ops_test
  {
      /// <summary>
      /// Excerpt of tensorflow/python/ops/control_flow_ops_test.py: checks that
      /// <c>control_flow_ops.cond</c> yields the value of the branch selected by its predicate.
      /// </summary>
      /// <remarks>
      /// The remaining Python test cases of that class are not ported: they are either
      /// unnecessary because of C#'s static typing or they test a deprecated API.
      /// </remarks>
      [TestClass]
      public class CondTestCases : PythonTest
      {
          [TestMethod]
          public void testCondTrue()
          {
              // 2 < 5 holds, so the true branch runs: 2 * 17 == 34.
              assertEquals(EvalCond(2, 5), 34);
          }

          [TestMethod]
          public void testCondFalse()
          {
              /* python
               * import tensorflow as tf
               * from tensorflow.python.framework import ops
               * def if_true():
               *     return tf.math.multiply(x, 17)
               * def if_false():
               *     return tf.math.add(y, 23)
               * with tf.Session() as sess:
               *     x = tf.constant(2)
               *     y = tf.constant(1)
               *     pred = tf.math.less(x,y)
               *     z = tf.cond(pred, if_true, if_false)
               *     result = z.eval()
               *     print(result == 24)
               */
              // 2 < 1 does not hold, so the false branch runs: 1 + 23 == 24.
              assertEquals(EvalCond(2, 1), 24);
          }

          /// <summary>
          /// Builds <c>cond(x &lt; y, x * 17, y + 23)</c> for the given scalar inputs,
          /// evaluates it and returns the scalar result.
          /// </summary>
          /// <param name="xValue">Value of the constant <c>x</c>.</param>
          /// <param name="yValue">Value of the constant <c>y</c>.</param>
          /// <returns><c>xValue * 17</c> when <c>xValue &lt; yValue</c>, otherwise <c>yValue + 23</c>.</returns>
          private int EvalCond(int xValue, int yValue)
          {
              int result = 0;

              // A fresh graph per evaluation keeps each test's ops isolated instead of
              // accumulating them in the process-wide default graph.
              var graph = tf.Graph().as_default();
              with(tf.Session(graph), sess =>
              {
                  var x = tf.constant(xValue);
                  var y = tf.constant(yValue);
                  var pred = tf.less(x, y);
                  Func<ITensorOrOperation> if_true = delegate
                  {
                      return tf.multiply(x, 17);
                  };
                  Func<ITensorOrOperation> if_false = delegate
                  {
                      return tf.add(y, 23);
                  };
                  var z = control_flow_ops.cond(pred, if_true, if_false);
                  result = z.eval(sess);
              });

              return result;
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。