You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GradientDescentOptimizerTests.cs 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Linq;
  4. using Tensorflow.NumPy;
  5. using TensorFlowNET.UnitTest;
  6. using static Tensorflow.Binding;
  7. namespace Tensorflow.Keras.UnitTest.Optimizers
  8. {
  /// <summary>
  /// Graph-mode tests for <c>tf.train.GradientDescentOptimizer</c>: applies one
  /// gradient step to two variables and checks the resulting values.
  /// </summary>
  [TestClass]
  public class GradientDescentOptimizerTest : PythonTest
  {
      /// <summary>
      /// Maps the numeric type <typeparamref name="T"/> to the dtype used when
      /// creating the test's variables and constants.
      /// </summary>
      /// <typeparam name="T">Numeric element type; only <see cref="float"/> and <see cref="double"/> are mapped.</typeparam>
      /// <returns><c>np.float32</c> for <see cref="float"/>, <c>np.float64</c> for <see cref="double"/>.</returns>
      /// <exception cref="NotImplementedException">Thrown for any other <typeparamref name="T"/>.</exception>
      private static TF_DataType GetTypeForNumericType<T>() where T : struct
      {
          return Type.GetTypeCode(typeof(T)) switch
          {
              TypeCode.Single => np.float32,
              TypeCode.Double => np.float64,
              _ => throw new NotImplementedException(
                  $"No dtype mapping is defined for numeric type '{typeof(T).Name}'."),
          };
      }

      /// <summary>
      /// Builds two variables with constant gradients, applies a single
      /// gradient-descent step, and asserts the values before and after the step.
      /// </summary>
      /// <typeparam name="T">Element type used for the variables and for reading results back.</typeparam>
      private void TestBasicGeneric<T>() where T : struct
      {
          // Shared by the optimizer and the expected-value arithmetic below so the two cannot drift apart.
          const float learningRate = 3.0f;

          var dtype = GetTypeForNumericType<T>();

          // train.GradientDescentOptimizer is V1 only API.
          tf.Graph().as_default();
          using (var sess = self.cached_session())
          {
              var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
              var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
              var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
              var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);

              var optimizer = tf.train.GradientDescentOptimizer(learningRate);
              var grads_and_vars = new[] {
                  Tuple.Create(grads0, var0 as IVariableV1),
                  Tuple.Create(grads1, var1 as IVariableV1)
              };
              var sgd_op = optimizer.apply_gradients(grads_and_vars);

              var global_variables = tf.global_variables_initializer();
              sess.run(global_variables);

              // Validate initial values
              self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
              self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));

              // Run 1 step of sgd
              sgd_op.run();

              // Validate updated params: each element is value - learningRate * gradient
              self.assertAllCloseAccordingToType(
                  new[] { 1.0 - learningRate * 0.1, 2.0 - learningRate * 0.1 },
                  self.evaluate<T[]>(var0));
              self.assertAllCloseAccordingToType(
                  new[] { 3.0 - learningRate * 0.01, 4.0 - learningRate * 0.01 },
                  self.evaluate<T[]>(var1));
              // TODO: self.assertEqual(0, len(optimizer.variables()));
          }
      }

      /// <summary>
      /// Runs the basic gradient-descent check for each supported element type.
      /// </summary>
      [TestMethod]
      public void TestBasic()
      {
          //TODO: add np.half
          TestBasicGeneric<float>();
          TestBasicGeneric<double>();
      }
  }
  67. }