
Optimizer.cs

/*****************************************************************************
   Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
******************************************************************************/

using System;
using System.Collections.Generic;
using System.Linq;
using Tensorflow.Framework;
using Tensorflow.Train;
using static Tensorflow.Python;
namespace Tensorflow
{
    /// <summary>
    /// Base class for optimizers.
    /// This class defines the API to add Ops to train a model. You never use this
    /// class directly, but instead instantiate one of its subclasses such as
    /// `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
    /// </summary>
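    /// <example>
    /// A minimal usage sketch (assuming the `GradientDescentOptimizer` subclass and an existing scalar `loss` tensor):
    /// <code>
    /// var optimizer = tf.train.GradientDescentOptimizer(0.01f);
    /// var train_op = optimizer.minimize(loss);
    /// </code>
    /// </example>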
    public abstract class Optimizer : Trackable
    {
        // Values for gate_gradients.
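        // GATE_NONE: compute and apply gradients in parallel, with no gating.
        // GATE_OP: for each op, make sure all of its gradients are computed before any of them are used.
        // GATE_GRAPH: make sure all gradients for all variables are computed before any of them are used.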
        public static int GATE_NONE = 0;
        public static int GATE_OP = 1;
        public static int GATE_GRAPH = 2;

        string _name;
        public string Name => _name;

        protected float _lr;
        public float LearningRate => _lr;

        protected Tensor _lr_t;
        public Tensor LearningRateTensor => _lr_t;

        public bool _use_locking;
        public Dictionary<string, Dictionary<string, RefVariable>> _slots;
        public Dictionary<string, RefVariable> _non_slot_dict;
        public Dictionary<string, object> _deferred_slot_restorations;
        SlotCreator slot_creator = new SlotCreator();

        public Optimizer(float learning_rate, bool use_locking, string name = null)
        {
            if (String.IsNullOrEmpty(name))
                throw new NotImplementedException("Must specify the optimizer name");

            _name = name;
            _use_locking = use_locking;
            _lr = learning_rate;
            // Dictionary of slots.
            _slots = new Dictionary<string, Dictionary<string, RefVariable>>();
            _non_slot_dict = new Dictionary<string, RefVariable>();
            _deferred_slot_restorations = new Dictionary<string, object>();
        }
        /// <summary>
        /// Add operations to minimize `loss` by updating `var_list`.
        ///
        /// This method simply combines calls to `compute_gradients()` and
        /// `apply_gradients()`. If you want to process the gradients before applying
        /// them, call `compute_gradients()` and `apply_gradients()` explicitly instead
        /// of using this function.
        /// </summary>
        /// <param name="loss">A `Tensor` containing the value to minimize.</param>
        /// <param name="global_step">Optional `Variable` to increment by one after the
        /// variables have been updated.</param>
        /// <param name="var_list">Optional list or tuple of `Variable` objects to update to
        /// minimize `loss`. Defaults to the list of variables collected in
        /// the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.</param>
        /// <param name="gate_gradients">
        /// How to gate the computation of gradients. Can be
        /// `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
        /// </param>
        /// <param name="aggregation_method">
        /// Specifies the method used to combine gradient terms.
        /// Valid values are defined in the class `AggregationMethod`.
        /// </param>
        /// <param name="colocate_gradients_with_ops"></param>
        /// <param name="name">Optional name for the returned operation.</param>
        /// <param name="grad_loss">Optional. A `Tensor` holding the gradient computed for `loss`.</param>
        /// <returns>
        /// An Operation that updates the variables in `var_list`. If `global_step`
        /// was not `None`, that operation also increments `global_step`.
        /// </returns>
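        /// <example>
        /// A minimal sketch with a global step counter (assumes `loss` exists and that `global_step`
        /// is a non-trainable integer variable created elsewhere; names are illustrative):
        /// <code>
        /// var train_op = optimizer.minimize(loss, global_step: global_step);
        /// </code>
        /// </example>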
        public Operation minimize(Tensor loss,
            RefVariable global_step = null,
            List<RefVariable> var_list = null,
            GateGradientType gate_gradients = GateGradientType.GATE_OP,
            int? aggregation_method = null,
            bool colocate_gradients_with_ops = false,
            string name = null,
            Tensor grad_loss = null)
        {
            // TODO: strongly type aggregation_method
            var grads_and_vars = compute_gradients(loss, var_list: var_list,
                gate_gradients: gate_gradients,
                aggregation_method: aggregation_method,
                colocate_gradients_with_ops: colocate_gradients_with_ops,
                grad_loss: grad_loss);

            var vars_with_grad = grads_and_vars.Where(x => x.Item1 != null).Select(x => x.Item2).ToArray();
            if (vars_with_grad.Length == 0)
                throw new ValueError($"No gradients provided for any variable, check your graph for ops" +
                    $" that do not support gradients, between variables {string.Join(",", grads_and_vars.Select(x => x.Item2.name))} and loss {loss}.");

            return apply_gradients(grads_and_vars, global_step: global_step, name: name);
        }
        /// <summary>
        /// Apply gradients to variables.
        ///
        /// This is the second part of `minimize()`. It returns an `Operation` that
        /// applies gradients.
        /// </summary>
        /// <param name="grads_and_vars">List of (gradient, variable) pairs as returned by
        /// `compute_gradients()`.</param>
        /// <param name="global_step">Optional `Variable` to increment by one after the
        /// variables have been updated.</param>
        /// <param name="name">Optional name for the returned operation. Defaults to the
        /// name passed to the `Optimizer` constructor.</param>
        /// <returns>
        /// An `Operation` that applies the specified gradients. If `global_step`
        /// was not None, that operation also increments `global_step`.
        /// </returns>
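        /// <example>
        /// A sketch of the two-step path described in `minimize()` (any gradient transformation
        /// in the middle is up to the caller; names are illustrative):
        /// <code>
        /// var grads_and_vars = optimizer.compute_gradients(loss);
        /// // ... optionally inspect or transform the gradients here ...
        /// var train_op = optimizer.apply_gradients(grads_and_vars, global_step: global_step);
        /// </code>
        /// </example>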
        public Operation apply_gradients(Tuple<Tensor, RefVariable>[] grads_and_vars, RefVariable global_step = null, string name = null)
        {
            // No DistributionStrategy case.
            var converted_grads_and_vars = new List<(Tensor, RefVariable, _OptimizableVariable)>();
            foreach (var (g, v) in grads_and_vars)
            {
                if (g != null)
                {
                    // Convert the grad to Tensor or IndexedSlices if necessary.
                    var gR = ops.convert_to_tensor_or_indexed_slices(g);
                    var p = _get_processor(v);
                    converted_grads_and_vars.Add((gR, v, p));
                }
            }

            var var_list = converted_grads_and_vars.Where(x => x.Item1 != null).Select(x => x.Item2).ToArray();
            if (var_list.Length == 0)
                throw new ValueError($"No gradients provided for any variable");

            ops.init_scope();
            _create_slots(var_list);

            var update_ops = new List<Operation>();
            return tf_with(ops.name_scope(name, Name), scope =>
            {
                name = scope;
                _prepare();

                foreach (var (grad, var, processor) in converted_grads_and_vars)
                {
                    if (grad == null)
                        continue;

                    var scope_name = var.op.name;
                    tf_with(ops.name_scope("update_" + scope_name), scope2 =>
                    {
                        var op = processor.update_op(this, grad);
                        update_ops.Add(op);
                    });
                }

                Operation apply_updates = null;
                if (global_step == null)
                {
                    apply_updates = _finish(update_ops.ToArray(), name);
                }
                else
                {
                    tf_with(ops.control_dependencies(new object[] { _finish(update_ops.ToArray(), "update") }), dep =>
                    {
                        ops.colocate_with(global_step);
                        // TODO: port this if branch once ResourceVariable has been ported!
                        //if (global_step is ResourceVariable)
                        //{
                        //    // TODO(apassos): the implicit read in assign_add is slow; consider
                        //    // making it less so.
                        //    apply_updates = resource_variable_ops.assign_add_variable_op(
                        //        global_step.handle,
                        //        ops.convert_to_tensor(1, dtype: global_step.dtype),
                        //        name: name);
                        //}
                        //else
                        {
                            apply_updates = state_ops.assign_add(global_step, tf.constant(1), name: name);
                        }
                    });
                }
                if (!tf.context.executing_eagerly())
                {
                    var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<ITensorOrOperation>;
                    // Register the op in the TRAIN_OP collection only if it is not already present.
                    if (train_op != null && !train_op.Contains(apply_updates))
                        train_op.Add(apply_updates);
                }

                return apply_updates;
            });
        }
        /// <summary>
        /// Create the beta1 and beta2 accumulators on the same device as the first
        /// variable. Sort the var_list to make sure this device is consistent across
        /// workers (these need to go on the same PS, otherwise some updates are
        /// silently ignored).
        /// </summary>
        /// <param name="var_list"></param>
        protected virtual void _create_slots(RefVariable[] var_list)
        {
        }

        /// <summary>
        /// Add an extra variable, not associated with a slot.
        /// </summary>
        /// <param name="initial_value"></param>
        /// <param name="name"></param>
        /// <param name="colocate_with"></param>
        protected RefVariable _create_non_slot_variable(float initial_value, string name, RefVariable colocate_with)
        {
            // Recommendation: Use OptimizerV2 if your optimizer uses non-slot variables.
            var graph = colocate_with.graph;
            var key = $"{name}.{graph.graph_key}";
            var v = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null;
            if (v == null)
            {
                _maybe_initialize_trackable();
                v = variable_scope.default_variable_creator(
                    initial_value, name: name, trainable: false,
                    use_resource: resource_variable_ops.is_resource_variable(
                        colocate_with));

                // Restore this variable by name if necessary, but don't add a
                // Trackable dependency. Optimizers return the current graph's
                // non-slot variables from _checkpoint_dependencies explicitly rather
                // than unconditionally adding dependencies (since there may be multiple
                // non-slot variables with the same name in different graphs, trying to
                // save all of them would result in errors).
                _handle_deferred_dependencies(name, v);
                _non_slot_dict[key] = v;
            }

            return v;
        }

        public virtual Operation _finish(Operation[] update_ops, string name_scope)
        {
            return control_flow_ops.group(update_ops, name_scope);
        }

        public virtual Operation _apply_dense(Tensor grad, RefVariable var)
        {
            var alpha = math_ops.cast(LearningRateTensor, var.dtype.as_base_dtype());
            return gen_training_ops.apply_gradient_descent(var, alpha, grad, use_locking: _use_locking).op;
        }

        /// <summary>
        /// Add ops to apply sparse gradients to `var`, with repeated sparse indices.
        /// </summary>
        /// <param name="grad"></param>
        /// <param name="var"></param>
        /// <returns></returns>
        public virtual Operation _apply_sparse_duplicate_indices(IndexedSlices grad, RefVariable var)
        {
            var (summed_values, unique_indices) = _deduplicate_indexed_slices(values: grad.values, indices: grad.indices);
            var gradient_no_duplicate_indices = new IndexedSlices(
                indices: unique_indices,
                values: summed_values,
                dense_shape: grad.dense_shape);
            return _apply_sparse(gradient_no_duplicate_indices, var);
        }

        public virtual Operation _apply_sparse(IndexedSlices grad, RefVariable var)
        {
            throw new NotImplementedException("_apply_sparse");
        }
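        /// <summary>
        /// Sums `values` associated with any non-unique `indices`.
        /// Illustrative example: indices [0, 2, 0] with value rows [a, b, c] become
        /// unique indices [0, 2] with summed value rows [a + c, b].
        /// </summary>
        /// <param name="values">A `Tensor` whose rows may be repeated according to `indices`.</param>
        /// <param name="indices">Indices into the first dimension of a larger tensor; may contain duplicates.</param>
        /// <returns>A tuple of (summed_values, unique_indices).</returns>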
        public virtual (Tensor, Tensor) _deduplicate_indexed_slices(Tensor values, Tensor indices)
        {
            var (unique_indices, new_index_positions) = array_ops.unique(indices);
            var shape = array_ops.shape(unique_indices).slice(0);
            var summed_values = math_ops.unsorted_segment_sum(values, new_index_positions, shape);
            return (summed_values, unique_indices);
        }
        public virtual void _prepare()
        {
        }

        /// <summary>
        /// Return a slot named `name` created for `var` by the Optimizer.
        /// </summary>
        /// <param name="var"></param>
        /// <param name="name"></param>
        /// <returns></returns>
        protected RefVariable get_slot(RefVariable var, string name)
        {
            var named_slots = _slots.ContainsKey(name) ? _slots[name] : null;
            if (named_slots == null)
                return null;

            return named_slots.ContainsKey(_var_key(var)) ? named_slots[_var_key(var)] : null;
        }

        private string _var_key(RefVariable var)
        {
            return $"{var.op.graph.graph_key}.{var.op.name}";
        }

        protected RefVariable _get_non_slot_variable(string name, Graph graph = null)
        {
            var key = $"{name}.{graph.graph_key}";
            var non_slot = _non_slot_dict.ContainsKey(key) ? _non_slot_dict[key] : null;
            return non_slot;
        }

        private _OptimizableVariable _get_processor(RefVariable v)
        {
            if (v is RefVariable)
            {
                return new _RefVariableProcessor(v);
            }
            else
            {
                throw new NotImplementedException("_get_processor");
            }
        }

        /// <summary>
        /// Compute gradients of `loss` for the variables in `var_list`.
        /// </summary>
        /// <param name="loss"></param>
        /// <param name="gate_gradients"></param>
        /// <returns>
        /// A list of (gradient, variable) pairs. Variable is always present, but
        /// gradient can be `None`.
        /// </returns>
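        /// <example>
        /// Sketch of inspecting the returned pairs (output formatting is illustrative):
        /// <code>
        /// foreach (var pair in optimizer.compute_gradients(loss))
        ///     Console.WriteLine($"{pair.Item2.name}: {(pair.Item1 == null ? "no gradient" : pair.Item1.name)}");
        /// </code>
        /// </example>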
        public Tuple<Tensor, RefVariable>[] compute_gradients(Tensor loss,
            List<RefVariable> var_list = null,
            int? aggregation_method = null,
            GateGradientType gate_gradients = GateGradientType.GATE_OP,
            bool colocate_gradients_with_ops = false,
            Tensor grad_loss = null)
        {
            // Scale loss if using a "mean" loss reduction and multiple replicas.
            loss = _scale_loss(loss);
            int num_towers = 1;

            var tmp = variables.trainable_variables();
            var vars = ops.get_collection<RefVariable>(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES);
            switch (tmp)
            {
                case List<RefVariable> values:
                    var_list = values.Concat(vars).ToList();
                    break;
                case List<VariableV1> values:
                    var_list = values.Select(x => x as RefVariable).Concat(vars).ToList();
                    break;
            }

            var_list = var_list.Concat(ops.get_collection<RefVariable>(ops.GraphKeys._STREAMING_MODEL_PORTS)).ToList();
            var processors = var_list.Select(v => optimizer._get_processor(v)).ToList();
            var var_refs = processors.Select(x => x.target()).ToArray();

            var grads = gradients_impl.gradients(new Tensor[] { loss }, var_refs, grad_ys: grad_loss == null ? null : new Tensor[] { grad_loss },
                gate_gradients: gate_gradients == GateGradientType.GATE_OP,
                aggregation_method: aggregation_method,
                colocate_gradients_with_ops: colocate_gradients_with_ops);

            if ((int)gate_gradients == Optimizer.GATE_GRAPH)
                grads = control_flow_ops.tuple(grads);

            var grads_and_vars = Python.zip(grads, var_list)
                .Select(x => new Tuple<Tensor, RefVariable>(x.Item1, x.Item2))
                .ToArray();

            return grads_and_vars;
        }
        private Tensor _scale_loss(Tensor loss_value)
        {
            ops.get_default_graph()._is_loss_scaled_by_optimizer = false;
            // TODO
            // if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
            return loss_value;
        }

        protected T _call_if_callable<T>(T param)
        {
            return param;
        }

        /// <summary>
        /// Find or create a slot initialized with 0.0.
        /// </summary>
        /// <param name="var"></param>
        /// <param name="slot_name"></param>
        /// <param name="op_name"></param>
        /// <returns></returns>
        protected RefVariable _zeros_slot(RefVariable var, string slot_name, string op_name)
        {
            var named_slots = _slot_dict(slot_name);
            if (!named_slots.ContainsKey(_var_key(var)))
            {
                var new_slot_variable = slot_creator.create_zeros_slot(var, op_name);
                _restore_slot_variable(slot_name: slot_name, variable: var, slot_variable: new_slot_variable);
                named_slots[_var_key(var)] = new_slot_variable;
            }

            return named_slots[_var_key(var)];
        }

        /// <summary>
        /// Restore a newly created slot variable's value.
        /// </summary>
        protected void _restore_slot_variable(string slot_name, RefVariable variable, RefVariable slot_variable)
        {
            var variable_key = _var_key(variable);
            // TODO
        }

        protected Dictionary<string, RefVariable> _slot_dict(string slot_name)
        {
            var named_slots = _slots.ContainsKey(slot_name) ? _slots[slot_name] : null;
            if (named_slots == null)
            {
                named_slots = new Dictionary<string, RefVariable>();
                _slots[slot_name] = named_slots;
            }

            return named_slots;
        }
    }
}
  393. }