You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ExponentialMovingAverage.cs 2.7 kB

6 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using static Tensorflow.Binding;
  6. namespace Tensorflow.Train
  7. {
  /// <summary>
  /// Maintains moving averages of variables. For every variable handed to
  /// <see cref="apply"/> a companion "average" slot is created and an update
  /// op is built through <c>moving_averages.assign_moving_average</c>.
  /// </summary>
  public class ExponentialMovingAverage
  {
    // Decay passed to assign_moving_average (converted to a tensor in apply).
    private readonly float _decay;

    // Optional update counter; apply() does not support it yet and throws when set.
    private readonly int? _num_updates;

    // Stored from the constructor. NOTE(review): not consulted by apply(),
    // which always passes zero_debias: false - confirm this is intended.
    private readonly bool _zero_debias;

    // Name used both for the created slots and for the enclosing name scope.
    private readonly string _name;

    /// <summary>Name given to this moving-average object.</summary>
    public string name => _name;

    // Maps each tracked variable to the slot variable holding its average.
    private readonly Dictionary<RefVariable, RefVariable> _averages;

    /// <summary>
    /// Creates a new moving-average helper.
    /// </summary>
    /// <param name="decay">Decay value forwarded to the moving-average update.</param>
    /// <param name="num_updates">Optional update count; currently unsupported by <see cref="apply"/>.</param>
    /// <param name="zero_debias">Stored, but not used by the current implementation of <see cref="apply"/>.</param>
    /// <param name="name">Name used for slots and for the op name scope.</param>
    public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false,
      string name = "ExponentialMovingAverage")
    {
      _decay = decay;
      _num_updates = num_updates;
      _zero_debias = zero_debias;
      _name = name;
      _averages = new Dictionary<RefVariable, RefVariable>();
    }

    /// <summary>
    /// Maintains moving averages of variables.
    /// </summary>
    /// <param name="var_list">
    /// Variables to track. When null, the result of
    /// <c>variables.trainable_variables()</c> is used.
    /// </param>
    /// <returns>An operation grouping the moving-average updates of all listed variables.</returns>
    /// <exception cref="NotImplementedException">
    /// Thrown when <c>num_updates</c> was supplied to the constructor, or when a
    /// variable in <paramref name="var_list"/> already has an average slot.
    /// </exception>
    /// <exception cref="InvalidOperationException">
    /// Thrown when <paramref name="var_list"/> is null and the trainable
    /// variables cannot be viewed as a sequence of <c>RefVariable</c>.
    /// </exception>
    public Operation apply(RefVariable[] var_list = null)
    {
      // Fail fast: reject the unsupported configuration before any slot is
      // created, so a failed call does not leave entries in _averages.
      if (_num_updates.HasValue)
      {
        throw new NotImplementedException("ExponentialMovingAverage.apply");
      }

      if (var_list == null)
      {
        // Go through object so the type test works whatever the declared
        // return type is; arrays and lists of RefVariable both qualify.
        object trainable = variables.trainable_variables();
        var typed = trainable as IEnumerable<RefVariable>;
        if (typed == null)
        {
          throw new InvalidOperationException(
            "ExponentialMovingAverage.apply: variables.trainable_variables() did not return " +
            "a collection of RefVariable; pass var_list explicitly.");
        }
        var_list = typed.ToArray();
      }

      foreach (var variable in var_list)
      {
        if (_averages.ContainsKey(variable))
        {
          // Reached when a variable is passed to apply() more than once
          // (across calls or duplicated inside var_list).
          // avg = slot_creator.create_zeros_slot(
          throw new NotImplementedException(
            "ExponentialMovingAverage.apply: a moving average slot already exists for a " +
            "variable in var_list; re-applying to the same variable is not implemented.");
        }

        ops.init_scope();
        var slot_creator = new SlotCreator();
        var value = variable.initialized_value();
        var avg = slot_creator.create_slot(variable,
          value,
          name,
          colocate_with_primary: true);
        ops.add_to_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES, variable);
        _averages[variable] = avg;
      }

      return tf_with(ops.name_scope(name), scope =>
      {
        var decay = ops.convert_to_tensor(_decay, name: "decay");

        var updates = new List<Tensor>();
        foreach (var variable in var_list)
        {
          var zero_debias = false;// _averages[var] in zero_debias_true
          var ama = moving_averages.assign_moving_average(_averages[variable], variable, decay, zero_debias: zero_debias);
          updates.Add(ama);
        }

        return control_flow_ops.group(updates.ToArray(), name: scope);
      });
    }
  }
  72. }