You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

OptimizerV2.cs 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow.Keras.ArgsDefinition;
  5. using Tensorflow.Keras.Utils;
  6. using Tensorflow.Train;
  7. using static Tensorflow.Binding;
  8. namespace Tensorflow.Keras.Optimizers
  9. {
    /// <summary>
    /// Updated base class for optimizers.
    /// </summary>
    public class OptimizerV2 : Trackable, IOptimizer
    {
        // Construction-time arguments (learning rate, initial decay, ...).
        OptimizerV2Args args;
        // Set once _create_hypers() has materialized hyperparameter variables.
        protected bool _hypers_created;
        // Name used as the outer name scope in apply_gradients.
        protected virtual string _name { get; }
        // Global step counter; created lazily in _create_all_weights.
        IVariableV1 _iterations;
        protected ResourceVariable iterations => _iterations as ResourceVariable;
        // All variables owned by this optimizer (iter, hypers, slots).
        List<IVariableV1> _weights;
        // Raw float hyperparameter values, keyed by name (e.g. "learning_rate").
        Dictionary<string, float> _hyper;
        // Variable-backed hyperparameters, created from _hyper by _create_hypers().
        Dictionary<string, IVariableV1> _hyper_variables;
        // Subclasses set this when they use momentum slots.
        protected bool _momentum;
        protected float _initial_decay = 0.0f;
        protected bool _use_locking = true;
        // Learning-rate variable; only valid after _create_hypers() has run.
        public IVariableV1 lr
            => _hyper_variables["learning_rate"];
        // Slot variables: variable UniqueId -> (slot name -> slot variable).
        Dictionary<string, Dictionary<string, IVariableV1>> _slots;
        List<string> _slot_names;
  30. public OptimizerV2(OptimizerV2Args args) : base()
  31. {
  32. this.args = args;
  33. _weights = new List<IVariableV1>();
  34. _hyper = new Dictionary<string, float>();
  35. _hyper_variables = new Dictionary<string, IVariableV1>();
  36. _slots = new Dictionary<string, Dictionary<string, IVariableV1>>();
  37. _slot_names = new List<string>();
  38. _set_hyper("learning_rate", args.LearningRate);
  39. _set_hyper("decay", args.InitialDecay);
  40. }
  41. public void apply_gradients((Tensor, ResourceVariable) grads_and_vars,
  42. string name = null,
  43. bool experimental_aggregate_gradients = true)
  44. => apply_gradients(new[] { grads_and_vars },
  45. name: name,
  46. experimental_aggregate_gradients: experimental_aggregate_gradients);
  47. /// <summary>
  48. /// Apply gradients to variables.
  49. /// </summary>
  50. /// <param name="grads_and_vars"></param>
  51. /// <param name="name"></param>
  52. /// <param name="experimental_aggregate_gradients"></param>
  53. public void apply_gradients(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
  54. string name = null,
  55. bool experimental_aggregate_gradients = true)
  56. {
  57. var var_list = grads_and_vars.Select(x => x.Item2).ToArray();
  58. tf_with(ops.name_scope(_name), delegate
  59. {
  60. ops.init_scope();
  61. _create_all_weights(var_list);
  62. if (grads_and_vars == null || grads_and_vars.Count() == 0)
  63. return control_flow_ops.no_op();
  64. var apply_state = _prepare(var_list);
  65. // if(experimental_aggregate_gradients)
  66. {
  67. // var reduced_grads = _aggregate_gradients(grads_and_vars);
  68. _distributed_apply(grads_and_vars, name, apply_state);
  69. }
  70. return null;
  71. });
  72. }
  73. void apply_grad_to_update_var(ResourceVariable var, Tensor grad, Dictionary<DeviceDType, Dictionary<string, Tensor>> apply_state)
  74. {
  75. _resource_apply_dense(var, grad, apply_state);
  76. // if var.constraint is not None:
  77. // with ops.control_dependencies([update_op]):
  78. // return var.assign(var.constraint(var))
  79. }
  80. protected virtual Operation _resource_apply_dense(IVariableV1 var,
  81. Tensor grad,
  82. Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
  83. {
  84. throw new NotImplementedException("_resource_apply_dense");
  85. }
  86. void _distributed_apply(IEnumerable<(Tensor, ResourceVariable)> grads_and_vars,
  87. string name,
  88. Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
  89. {
  90. tf_with(ops.name_scope(name, "", new { skip_on_eager = true }), delegate
  91. {
  92. foreach (var (grad, var) in grads_and_vars)
  93. {
  94. tf_with(ops.name_scope("update"), delegate
  95. {
  96. apply_grad_to_update_var(var, grad, _apply_state);
  97. });
  98. }
  99. _iterations.assign_add(ops.convert_to_tensor(1, dtype: _iterations.dtype));
  100. });
  101. }
  102. public Tensor[] _aggregate_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars)
  103. {
  104. return grads_and_vars.Select(x => x.Item1).ToArray();
  105. }
  106. public Tensor[] _clip_gradients(Tensor[] grads)
  107. {
  108. return grads;
  109. }
  110. protected IVariableV1 get_slot(IVariableV1 var, string slot_name)
  111. {
  112. var slot_dict = _slots[var.UniqueId];
  113. return slot_dict[slot_name];
  114. }
  115. Dictionary<DeviceDType, Dictionary<string, Tensor>> _prepare(IVariableV1[] var_list)
  116. {
  117. var _apply_state = new Dictionary<DeviceDType, Dictionary<string, Tensor>>();
  118. var keys = var_list.Select(x => new DeviceDType
  119. {
  120. Device = x.Device,
  121. DType = x.dtype.as_base_dtype()
  122. }).Distinct(new DeviceDType()).ToArray();
  123. foreach (var device_dtype in keys)
  124. {
  125. _apply_state[device_dtype] = new Dictionary<string, Tensor>();
  126. _prepare_local(device_dtype, _apply_state);
  127. }
  128. return _apply_state;
  129. }
  130. protected Dictionary<string, Tensor> _fallback_apply_state(string var_device, TF_DataType var_dtype)
  131. {
  132. throw new NotImplementedException("");
  133. }
  134. protected virtual void _prepare_local(DeviceDType device_dtype,
  135. Dictionary<DeviceDType, Dictionary<string, Tensor>> _apply_state)
  136. {
  137. if (_hyper.ContainsKey("learning_rate"))
  138. {
  139. var lr_t = array_ops.identity(_decayed_lr(device_dtype.DType));
  140. _apply_state[device_dtype]["lr_t"] = lr_t;
  141. }
  142. }
  143. Tensor _decayed_lr(TF_DataType var_dtype)
  144. {
  145. var lr_t = _get_hyper("learning_rate", var_dtype);
  146. if (_initial_decay > 0.0f)
  147. {
  148. throw new NotImplementedException("");
  149. }
  150. return lr_t;
  151. }
  152. protected Tensor _get_hyper(string name, TF_DataType dtype = TF_DataType.DtInvalid)
  153. {
  154. var value = _hyper_variables[name];
  155. return math_ops.cast(value, dtype);
  156. }
  157. void _create_all_weights(IVariableV1[] var_list)
  158. {
  159. if (_iterations == null)
  160. {
  161. _iterations = add_weight("iter",
  162. shape: new int[0],
  163. dtype: TF_DataType.TF_INT64,
  164. trainable: false,
  165. aggregation: VariableAggregation.OnlyFirstReplica);
  166. _weights.Add(_iterations);
  167. }
  168. _create_hypers();
  169. _create_slots(var_list);
  170. }
  171. protected void _set_hyper(string name, float value)
  172. {
  173. _hyper[name] = value;
  174. }
  175. void _create_hypers()
  176. {
  177. if (_hypers_created)
  178. return;
  179. foreach (var dict in _hyper)
  180. {
  181. var name = dict.Key;
  182. var value = dict.Value;
  183. _hyper_variables[name] = add_weight(
  184. name,
  185. shape: new int[0],
  186. trainable: false,
  187. initializer: tf.constant_initializer(value),
  188. aggregation: VariableAggregation.OnlyFirstReplica);
  189. }
  190. _hypers_created = true;
  191. }
  192. protected virtual void _create_slots(IVariableV1[] var_list)
  193. {
  194. if (_momentum)
  195. {
  196. /*for var in var_list:
  197. self.add_slot(var, "momentum")*/
  198. }
  199. }
  200. protected IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null)
  201. {
  202. if (initializer == null)
  203. initializer = tf.zeros_initializer;
  204. if (!_slot_names.Contains(slot_name))
  205. _slot_names.append(slot_name);
  206. if (!_slots.ContainsKey(var.UniqueId))
  207. _slots[var.UniqueId] = new Dictionary<string, IVariableV1>();
  208. var slot_dict = _slots[var.UniqueId];
  209. if (!slot_dict.ContainsKey(slot_name))
  210. {
  211. var weight = tf.Variable(initializer,
  212. dtype: var.dtype,
  213. trainable: false,
  214. shape: var.shape,
  215. name: $"{var.Name}/{slot_name}");
  216. slot_dict[slot_name] = weight;
  217. _weights.append(weight);
  218. return weight;
  219. }
  220. else
  221. {
  222. return slot_dict[slot_name];
  223. }
  224. }
  225. ResourceVariable add_weight(string name,
  226. Shape shape,
  227. TF_DataType dtype = TF_DataType.TF_FLOAT,
  228. IInitializer initializer = null,
  229. bool trainable = false,
  230. VariableSynchronization synchronization = VariableSynchronization.Auto,
  231. VariableAggregation aggregation = VariableAggregation.None)
  232. {
  233. if (initializer == null)
  234. initializer = tf.zeros_initializer;
  235. if (dtype == TF_DataType.DtInvalid)
  236. dtype = TF_DataType.TF_FLOAT;
  237. var variable = _add_variable_with_custom_getter(new VariableArgs
  238. {
  239. Name = name,
  240. Shape = shape,
  241. Getter = base_layer_utils.make_variable,
  242. DType = dtype,
  243. Overwrite = true,
  244. Initializer = initializer,
  245. Trainable = trainable,
  246. UseResource = true,
  247. Synchronization = synchronization,
  248. Aggregation = aggregation
  249. });
  250. return variable as ResourceVariable;
  251. }
  252. }
  253. }