
Optimizer.cs 8.0 kB

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using distribute_lib = Tensorflow.Distribute;

namespace Tensorflow
{
    /// <summary>
    /// Base class for optimizers.
    /// This class defines the API to add Ops to train a model. You never use this
    /// class directly, but instead instantiate one of its subclasses such as
    /// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
    /// </summary>
    public abstract class Optimizer : Python
    {
        // Values for gate_gradients.
        public static int GATE_NONE = 0;
        public static int GATE_OP = 1;
        public static int GATE_GRAPH = 2;

        public string Name { get; set; }
        public float LearningRate { get; set; }
        public Tensor LearningRateTensor { get; set; }
        public bool _use_locking;
        public Dictionary<string, object> _slots;
        public Dictionary<string, object> _non_slot_dict;
        public Dictionary<string, object> _deferred_slot_restorations;
        public Optimizer(float learning_rate, bool use_locking, string name = null)
        {
            // A missing name is a usage error, not an unimplemented feature.
            if (string.IsNullOrEmpty(name))
                throw new ValueError("Must specify the optimizer name");

            Name = name;
            // Store the learning rate so subclasses can build the learning-rate
            // tensor in _prepare(); the parameter was previously unused.
            LearningRate = learning_rate;
            _use_locking = use_locking;
            // Dictionaries of slots.
            _slots = new Dictionary<string, object>();
            _non_slot_dict = new Dictionary<string, object>();
            _deferred_slot_restorations = new Dictionary<string, object>();
        }
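
        // Subclass sketch (illustrative, not part of this file): a concrete optimizer
        // forwards its hyper-parameters to the base constructor, e.g.
        //
        //     public GradientDescentOptimizer(float learning_rate, bool use_locking = false)
        //         : base(learning_rate, use_locking, name: "GradientDescent") { }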
        /// <summary>
        /// Add operations to minimize `loss` by updating `var_list`.
        /// </summary>
        /// <param name="loss">A `Tensor` containing the value to minimize.</param>
        /// <returns>
        /// An Operation that updates the variables in `var_list`. If `global_step`
        /// was not `None`, that operation also increments `global_step`.
        /// </returns>
        public Operation minimize(Tensor loss,
            GateGradientType gate_gradients = GateGradientType.GATE_OP,
            bool colocate_gradients_with_ops = false)
        {
            var grads_and_vars = compute_gradients(loss,
                gate_gradients: gate_gradients,
                colocate_gradients_with_ops: colocate_gradients_with_ops);

            var vars_with_grad = grads_and_vars.Where(x => x.Item1 != null).Select(x => x.Item2).ToArray();
            if (vars_with_grad.Length == 0)
                // Report all candidate variables; `vars_with_grad` is always empty on this path.
                throw new ValueError("No gradients provided for any variable, check your graph for ops" +
                    $" that do not support gradients, between variables {string.Join(",", grads_and_vars.Select(x => x.Item2.name))} and loss {loss}.");

            return apply_gradients(grads_and_vars);
        }
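
        // Usage sketch (illustrative, not part of this file): with a concrete subclass,
        // training typically reduces to
        //
        //     var optimizer = new GradientDescentOptimizer(learning_rate: 0.01f);
        //     var train_op = optimizer.minimize(loss);
        //
        // where `loss` is a scalar Tensor; the subclass constructor shown is an assumption.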
        public Operation apply_gradients(Tuple<Tensor, RefVariable>[] grads_and_vars, Tensor global_step = null, string name = null)
        {
            // No DistributionStrategy case.
            var converted_grads_and_vars = new List<Tuple<Tensor, RefVariable, _OptimizableVariable>>();
            foreach (var (g, v) in grads_and_vars)
            {
                if (g != null)
                {
                    // Convert the grad to Tensor or IndexedSlices if necessary.
                    var gR = ops.convert_to_tensor_or_indexed_slices(g);
                    var p = _get_processor(v);
                    converted_grads_and_vars.Add(new Tuple<Tensor, RefVariable, _OptimizableVariable>(gR, v, p));
                }
            }

            var var_list = converted_grads_and_vars.Where(x => x.Item1 != null).Select(x => x.Item2).ToArray();
            if (var_list.Length == 0)
                throw new ValueError("No gradients provided for any variable");

            ops.init_scope();
            _create_slots(var_list);

            var update_ops = new List<Operation>();
            return with(new ops.name_scope(name, Name), scope =>
            {
                name = scope;
                _prepare();
                foreach (var (grad, var, processor) in converted_grads_and_vars)
                {
                    if (grad == null)
                        continue;

                    var scope_name = var.op.name;
                    with(new ops.name_scope("update_" + scope_name), scope2 =>
                    {
                        update_ops.Add(processor.update_op(this, grad));
                    });
                }

                Operation apply_updates = null;
                if (global_step == null)
                {
                    apply_updates = _finish(update_ops.ToArray(), name);
                }
                else
                {
                    // TODO: also increment `global_step` after the updates, as the doc
                    // comment on minimize() describes. Fail loudly instead of silently
                    // returning a null op.
                    throw new NotImplementedException("apply_gradients with global_step");
                }

                if (!tf.context.executing_eagerly())
                {
                    var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<object>;
                    if (!train_op.Contains(apply_updates))
                        train_op.Add(apply_updates);
                }

                return apply_updates;
            });
        }
        /// <summary>
        /// Create the slots needed by the variables. No-op in the base class;
        /// stateful optimizers override this.
        /// </summary>
        private void _create_slots(RefVariable[] var_list)
        {
        }

        public virtual Operation _finish(Operation[] update_ops, string name_scope)
        {
            return control_flow_ops.group(update_ops, name_scope);
        }

        public virtual Operation _apply_dense(Tensor grad, RefVariable var)
        {
            // Dense gradient-descent step: var -= alpha * grad, with alpha cast
            // to the variable's base dtype.
            var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype());
            return gen_training_ops.apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op;
        }

        public virtual void _prepare()
        {
        }

        private _OptimizableVariable _get_processor(RefVariable v)
        {
            // Only `RefVariable` is supported so far; other variable kinds would
            // need their own processors.
            if (v is RefVariable)
            {
                return new _RefVariableProcessor(v);
            }
            else
            {
                throw new NotImplementedException("_get_processor");
            }
        }
        /// <summary>
        /// Compute gradients of `loss` for the variables in `var_list`.
        /// </summary>
        /// <param name="loss">A `Tensor` containing the value to minimize.</param>
        /// <param name="gate_gradients">How much parallelism to allow when computing gradients.</param>
        /// <returns>
        /// A list of (gradient, variable) pairs. Variable is always present, but
        /// gradient can be `None`.
        /// </returns>
        public Tuple<Tensor, RefVariable>[] compute_gradients(Tensor loss,
            List<RefVariable> var_list = null,
            int? aggregation_method = null,
            GateGradientType gate_gradients = GateGradientType.GATE_OP,
            bool colocate_gradients_with_ops = false,
            Tensor[] grad_loss = null)
        {
            int num_towers = 1;
            if (distribute_lib.get_loss_reduction() == VariableAggregationType.MEAN)
            {
                // TODO: scale the loss by `num_towers` when the reduction is MEAN.
            }

            // Default to all trainable variables only when no explicit list is given;
            // previously the caller-supplied var_list was always overwritten.
            if (var_list == null && variables.trainable_variables() is List<RefVariable> values)
                var_list = values;

            var processors = var_list.Select(v => _get_processor(v)).ToList();
            var var_refs = processors.Select(x => x.target()).ToArray();
            var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss,
                gate_gradients: (gate_gradients == GateGradientType.GATE_OP),
                aggregation_method: aggregation_method,
                colocate_gradients_with_ops: colocate_gradients_with_ops);

            if (gate_gradients == GateGradientType.GATE_GRAPH)
                grads = control_flow_ops.tuple(grads);

            var grads_and_vars = Python.zip(grads, var_list)
                .Select(x => new Tuple<Tensor, RefVariable>(x.Item1, x.Item2))
                .ToArray();

            return grads_and_vars;
        }
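
        // Sketch (illustrative, not part of this file): minimize() is equivalent to the
        // two phases run back to back, which leaves room to transform gradients in between:
        //
        //     var grads_and_vars = optimizer.compute_gradients(loss);
        //     // ... clip or otherwise rewrite the gradients here ...
        //     var train_op = optimizer.apply_gradients(grads_and_vars);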
        protected T _call_if_callable<T>(T param)
        {
            // Mirrors Python's `_call_if_callable`; parameters are plain values in
            // this port, so it is simply the identity.
            return param;
        }
    }
}

The .NET port of the TensorFlow framework. It provides a rich set of features and APIs for conveniently building deep-learning training and inference pipelines on the .NET platform.
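
As an end-to-end illustration of the API defined in Optimizer.cs above, here is a minimal linear-regression sketch. It is a hedged example, not code from this repository: it assumes a `GradientDescentOptimizer(float learning_rate)` subclass, a graph-mode `tf.Session`, `FeedItem` feeding, and the `with` helper from the `Python` base class, all as in typical TensorFlow.NET samples.

    using Tensorflow;

    // Simple linear model pred = W * x + b, fitted by gradient descent.
    var X = tf.placeholder(tf.float32);
    var Y = tf.placeholder(tf.float32);
    var W = tf.Variable(-0.06f, name: "weight");
    var b = tf.Variable(-0.73f, name: "bias");

    var pred = tf.add(tf.multiply(X, W), b);
    var loss = tf.reduce_mean(tf.pow(pred - Y, 2.0f)); // mean squared error

    // minimize() = compute_gradients() followed by apply_gradients().
    var optimizer = new GradientDescentOptimizer(learning_rate: 0.01f);
    var train_op = optimizer.minimize(loss);

    with(tf.Session(), sess =>
    {
        sess.run(tf.global_variables_initializer());
        for (int step = 0; step < 1000; step++)
            sess.run(train_op, new FeedItem(X, 3.3f), new FeedItem(Y, 1.7f));
    });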