You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

Optimizer.cs 3.8 kB

6 years ago
6 years ago
6 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using distribute_lib = Tensorflow.Distribute;
  6. namespace Tensorflow
  7. {
  8. /// <summary>
  9. /// Base class for optimizers.
  10. /// This class defines the API to add Ops to train a model. You never use this
  11. /// class directly, but instead instantiate one of its subclasses such as
  12. /// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
  13. /// </summary>
  14. public abstract class Optimizer
  15. {
  16. public string Name { get; set; }
  17. public double LearningRate { get; set; }
  18. public Tensor LearningRateTensor { get; set; }
  19. public bool _use_locking;
  20. public Dictionary<string, object> _slots;
  21. public Dictionary<string, object> _non_slot_dict;
  22. public Dictionary<string, object> _deferred_slot_restorations;
  23. public Optimizer(double learning_rate, bool use_locking, string name = "")
  24. {
  25. if (String.IsNullOrEmpty(name))
  26. throw new NotImplementedException("Must specify the optimizer name");
  27. Name = name;
  28. _use_locking = use_locking;
  29. // Dictionary of slots.
  30. _slots = new Dictionary<string, object>();
  31. _non_slot_dict = new Dictionary<string, object>();
  32. _deferred_slot_restorations = new Dictionary<string, object>();
  33. }
  34. /// <summary>
  35. /// Add operations to minimize `loss` by updating `var_list`
  36. /// </summary>
  37. /// <param name="loss"></param>
  38. /// <returns>
  39. /// An Operation that updates the variables in `var_list`. If `global_step`
  40. /// was not `None`, that operation also increments `global_step`.
  41. /// </returns>
  42. public Operation minimize(Tensor loss,
  43. GateGradientType gate_gradients = GateGradientType.GATE_OP,
  44. bool colocate_gradients_with_ops = false)
  45. {
  46. var grads_and_vars = compute_gradients(loss,
  47. gate_gradients: gate_gradients,
  48. colocate_gradients_with_ops: colocate_gradients_with_ops);
  49. return null;
  50. }
  51. /// <summary>
  52. /// Compute gradients of `loss` for the variables in `var_list`.
  53. /// </summary>
  54. /// <param name="loss"></param>
  55. /// <param name="gate_gradients"></param>
  56. /// <returns>
  57. /// A list of (gradient, variable) pairs. Variable is always present, but
  58. /// gradient can be `None`.
  59. /// </returns>
  60. public List<KeyValuePair<object, object>> compute_gradients(Tensor loss,
  61. List<RefVariable> var_list = null,
  62. int? aggregation_method = null,
  63. GateGradientType gate_gradients = GateGradientType.GATE_OP,
  64. bool colocate_gradients_with_ops = false,
  65. List<Tensor> grad_loss = null)
  66. {
  67. int num_towers = 1;
  68. if(distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN)
  69. {
  70. }
  71. var tmp = variables.trainable_variables();
  72. switch (tmp)
  73. {
  74. case List<RefVariable> values:
  75. var_list = values;
  76. break;
  77. }
  78. var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
  79. var var_refs = processors.Select(x => x.target()).ToList();
  80. gradients_impl.gradients(loss, var_refs, grad_ys: grad_loss,
  81. gate_gradients: (gate_gradients == GateGradientType.GATE_OP),
  82. aggregation_method: aggregation_method,
  83. colocate_gradients_with_ops: colocate_gradients_with_ops);
  84. return null;
  85. }
  86. }
  87. }

tensorflow框架的.NET版本,提供了丰富的特性和API,可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。