You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ExponentialMovingAverage.cs 1.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using static Tensorflow.Binding;
  namespace Tensorflow.Train
  {
      /// <summary>
      /// Configuration holder for maintaining moving averages of variables.
      /// </summary>
      /// <remarks>
      /// Work in progress: <see cref="apply"/> does not build any averaging ops yet and
      /// always ends by throwing <see cref="NotImplementedException"/>.
      /// </remarks>
      public class ExponentialMovingAverage
      {
          // Values captured by the constructor; apply() does not read them yet.
          private readonly float _decay;
          private readonly int? _num_updates;
          private readonly bool _zero_debias;
          private readonly string _name;

          /// <summary>The name passed to the constructor.</summary>
          public string name => _name;

          // Variables that already have an average. Nothing adds to this list yet,
          // so every variable visited by apply() is treated as new.
          private readonly List<VariableV1> _averages;

          /// <summary>
          /// Creates the moving-average helper and stores its configuration.
          /// </summary>
          /// <param name="decay">Stored in the instance; presumably the moving-average decay rate (not yet used).</param>
          /// <param name="num_updates">Optional value stored in the instance (not yet used).</param>
          /// <param name="zero_debias">Flag stored in the instance (not yet used).</param>
          /// <param name="name">Name exposed through <see cref="name"/>.</param>
          public ExponentialMovingAverage(float decay, int? num_updates = null, bool zero_debias = false,
              string name = "ExponentialMovingAverage")
          {
              _decay = decay;
              _num_updates = num_updates;
              _zero_debias = zero_debias;
              _name = name;
              _averages = new List<VariableV1>();
          }

          /// <summary>
          /// Intended to maintain moving averages of variables; currently incomplete.
          /// </summary>
          /// <param name="var_list">
          /// Variables to average. When null, the <c>RefVariable</c> entries returned by
          /// <c>variables.trainable_variables()</c> are used (an empty set if it returns null).
          /// </param>
          /// <returns>Never returns normally at present.</returns>
          /// <exception cref="NotImplementedException">Always thrown after the variables are visited.</exception>
          public Operation apply(RefVariable[] var_list = null)
          {
              if (var_list == null)
              {
                  // trainable_variables() may hand back a collection that is not a RefVariable[]
                  // (e.g. a List); a plain `as RefVariable[]` would then be null and the foreach
                  // below would throw NullReferenceException. Accept any enumerable instead.
                  var_list = (variables.trainable_variables() as System.Collections.IEnumerable)
                      ?.OfType<RefVariable>()
                      .ToArray()
                      ?? Array.Empty<RefVariable>();
              }

              foreach (var variable in var_list)
              {
                  if (!_averages.Contains(variable))
                  {
                      // NOTE(review): confirm whether init_scope() returns a scope that must be
                      // entered/disposed; its result is currently discarded.
                      ops.init_scope();
                      variable.initialized_value();
                      // TODO: create the average slot (e.g. SlotCreator.create_zeros_slot)
                      // and record `variable` in _averages.
                  }
              }

              throw new NotImplementedException(
                  $"{nameof(ExponentialMovingAverage)}.{nameof(apply)} is not implemented yet.");
          }
      }
  }