GradientDescentOptimizerTests.cs

using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq;
using Tensorflow.NumPy;
using TensorFlowNET.UnitTest;
using static Tensorflow.Binding;

namespace Tensorflow.Keras.UnitTest.Optimizers
{
    [TestClass]
    public class GradientDescentOptimizerTest : PythonTest
    {
        private static TF_DataType GetTypeForNumericType<T>() where T : struct
        {
            return Type.GetTypeCode(typeof(T)) switch
            {
                TypeCode.Single => np.float32,
                TypeCode.Double => np.float64,
                _ => throw new NotImplementedException(),
            };
        }

        private void TestBasic<T>() where T : struct
        {
            var dtype = GetTypeForNumericType<T>();

            // train.GradientDescentOptimizer is V1 only API.
            tf.Graph().as_default();
            using (var sess = self.cached_session())
            {
                var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
                var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
                var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
                var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
                var optimizer = tf.train.GradientDescentOptimizer(3.0f);
                var grads_and_vars = new[]
                {
                    Tuple.Create(grads0, var0 as IVariableV1),
                    Tuple.Create(grads1, var1 as IVariableV1)
                };
                var sgd_op = optimizer.apply_gradients(grads_and_vars);

                var global_variables = tf.global_variables_initializer();
                sess.run(global_variables);

                var initialVar0 = sess.run(var0);
                var initialVar1 = sess.run(var1);

                // Fetch params to validate initial values
                self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
                self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));

                // Run 1 step of sgd
                sgd_op.run();

                // Validate updated params
                self.assertAllCloseAccordingToType(
                    new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 },
                    self.evaluate<T[]>(var0));
                self.assertAllCloseAccordingToType(
                    new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 },
                    self.evaluate<T[]>(var1));

                // TODO: self.assertEqual(0, len(optimizer.variables()));
            }
        }

        [TestMethod]
        public void TestBasic()
        {
            // TODO: add np.half
            TestBasic<float>();
            TestBasic<double>();
        }
    }
}
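
For reference, the values asserted after sgd_op.run() follow the plain gradient-descent update var <- var - learning_rate * grad. A minimal standalone sketch of that arithmetic (plain C#; SgdStep is a hypothetical helper for illustration only, not part of this test or of TensorFlow.NET):

using System;
using System.Linq;

// Hypothetical helper: applies one SGD step element-wise, var <- var - learningRate * grad.
static double[] SgdStep(double[] vars, double[] grads, double learningRate) =>
    vars.Select((v, i) => v - learningRate * grads[i]).ToArray();

// Learning rate 3.0 and the gradients used in the test above:
var var0 = SgdStep(new[] { 1.0, 2.0 }, new[] { 0.1, 0.1 }, 3.0);    // { 0.7, 1.7 }
var var1 = SgdStep(new[] { 3.0, 4.0 }, new[] { 0.01, 0.01 }, 3.0);  // { 2.97, 3.97 }
Console.WriteLine($"{string.Join(", ", var0)} | {string.Join(", ", var1)}");

These are exactly the values the assertAllCloseAccordingToType calls compare against (for example, 1.0 - 3.0 * 0.1 = 0.7).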