You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

Model.Fit.cs 12 kB

2 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315
  1. using Tensorflow.NumPy;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow.Keras.ArgsDefinition;
  6. using Tensorflow.Keras.Engine.DataAdapters;
  7. using System.Diagnostics;
  8. using Tensorflow.Keras.Callbacks;
  9. using Tensorflow.Util;
  10. namespace Tensorflow.Keras.Engine
  11. {
  12. public partial class Model
  13. {
  14. /// <summary>
  15. /// Trains the model for a fixed number of epochs (iterations on a dataset).
  16. /// </summary>
  17. /// <param name="x"></param>
  18. /// <param name="y"></param>
  19. /// <param name="batch_size"></param>
  20. /// <param name="epochs"></param>
  21. /// <param name="verbose"></param>
  22. /// <param name="callbacks"></param>
  23. /// <param name="validation_split"></param>
  24. /// <param name="validation_data"></param>
  25. /// <param name="shuffle"></param>
  26. /// <param name="class_weight"></param>
  27. /// <param name="sample_weight"></param>
  28. /// <param name="initial_epoch"></param>
  29. /// <param name="max_queue_size"></param>
  30. /// <param name="workers"></param>
  31. /// <param name="use_multiprocessing"></param>
  32. /// <returns></returns>
  33. /// <exception cref="InvalidArgumentError"></exception>
  34. public ICallback fit(NDArray x, NDArray y,
  35. int batch_size = -1,
  36. int epochs = 1,
  37. int verbose = 1,
  38. List<ICallback> callbacks = null,
  39. float validation_split = 0f,
  40. ValidationDataPack validation_data = null,
  41. int validation_step = 10,
  42. bool shuffle = true,
  43. Dictionary<int, float> class_weight = null,
  44. NDArray sample_weight = null,
  45. int initial_epoch = 0,
  46. int max_queue_size = 10,
  47. int workers = 1,
  48. bool use_multiprocessing = false)
  49. {
  50. if (x.dims[0] != y.dims[0])
  51. {
  52. throw new InvalidArgumentError(
  53. $"The array x and y should have same value at dim 0, but got {x.dims[0]} and {y.dims[0]}");
  54. }
  55. // The default dtype in NDArray is double, so we need to cast sample_weight to float to mul with loss which's dtype is float.
  56. sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT);
  57. if (validation_split != 0f && validation_data == null)
  58. {
  59. ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
  60. }
  61. var data_handler = new DataHandler(new DataHandlerArgs
  62. {
  63. X = x,
  64. Y = y,
  65. SampleWeight = sample_weight,
  66. BatchSize = batch_size,
  67. InitialEpoch = initial_epoch,
  68. Epochs = epochs,
  69. Shuffle = shuffle,
  70. ClassWeight = class_weight,
  71. MaxQueueSize = max_queue_size,
  72. Workers = workers,
  73. UseMultiprocessing = use_multiprocessing,
  74. Model = this,
  75. StepsPerExecution = _steps_per_execution
  76. });
  77. return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
  78. train_step_func: train_step_function);
  79. }
  80. public ICallback fit(IEnumerable<NDArray> x, NDArray y,
  81. int batch_size = -1,
  82. int epochs = 1,
  83. int verbose = 1,
  84. List<ICallback> callbacks = null,
  85. float validation_split = 0f,
  86. ValidationDataPack validation_data = null,
  87. bool shuffle = true,
  88. Dictionary<int, float> class_weight = null,
  89. NDArray sample_weight = null,
  90. int initial_epoch = 0,
  91. int max_queue_size = 10,
  92. int workers = 1,
  93. bool use_multiprocessing = false)
  94. {
  95. foreach(var tx in x)
  96. {
  97. if (tx.dims[0] != y.dims[0])
  98. {
  99. throw new InvalidArgumentError(
  100. $"The array x and y should have same value at dim 0, but got {tx.dims[0]} and {y.dims[0]}");
  101. }
  102. }
  103. sample_weight = sample_weight?.astype(TF_DataType.TF_FLOAT);
  104. if (validation_split != 0f && validation_data == null)
  105. {
  106. ((x, y, sample_weight), validation_data) = DataAdapter.train_validation_split((x, y, sample_weight), validation_split);
  107. }
  108. var data_handler = new DataHandler(new DataHandlerArgs
  109. {
  110. X = new Tensors(x.ToArray()),
  111. Y = y,
  112. SampleWeight = sample_weight,
  113. BatchSize = batch_size,
  114. InitialEpoch = initial_epoch,
  115. Epochs = epochs,
  116. Shuffle = shuffle,
  117. ClassWeight = class_weight,
  118. MaxQueueSize = max_queue_size,
  119. Workers = workers,
  120. UseMultiprocessing = use_multiprocessing,
  121. Model = this,
  122. StepsPerExecution = _steps_per_execution
  123. });
  124. if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
  125. data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
  126. {
  127. return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
  128. train_step_func: train_step_multi_inputs_function);
  129. }
  130. else
  131. {
  132. return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: validation_data,
  133. train_step_func: train_step_function);
  134. }
  135. }
  136. public ICallback fit(IDatasetV2 dataset,
  137. int batch_size = -1,
  138. int epochs = 1,
  139. int verbose = 1,
  140. List<ICallback> callbacks = null,
  141. IDatasetV2 validation_data = null,
  142. int validation_step = 10,
  143. bool shuffle = true,
  144. Dictionary<int, float> class_weight = null,
  145. int initial_epoch = 0,
  146. int max_queue_size = 10,
  147. int workers = 1,
  148. bool use_multiprocessing = false)
  149. {
  150. var data_handler = new DataHandler(new DataHandlerArgs
  151. {
  152. Dataset = dataset,
  153. BatchSize = batch_size,
  154. InitialEpoch = initial_epoch,
  155. Epochs = epochs,
  156. Shuffle = shuffle,
  157. ClassWeight = class_weight,
  158. MaxQueueSize = max_queue_size,
  159. Workers = workers,
  160. UseMultiprocessing = use_multiprocessing,
  161. Model = this,
  162. StepsPerExecution = _steps_per_execution
  163. });
  164. return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data,
  165. train_step_func: train_step_function);
  166. }
  167. History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
  168. Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
  169. {
  170. stop_training = false;
  171. _train_counter.assign(0);
  172. var callbacks = new CallbackList(new CallbackParams
  173. {
  174. Model = this,
  175. Verbose = verbose,
  176. Epochs = epochs,
  177. Steps = data_handler.Inferredsteps
  178. });
  179. if (callbackList != null)
  180. {
  181. foreach(var callback in callbackList)
  182. callbacks.callbacks.add(callback);
  183. }
  184. callbacks.on_train_begin();
  185. foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
  186. {
  187. reset_metrics();
  188. callbacks.on_epoch_begin(epoch);
  189. // data_handler.catch_stop_iteration();
  190. var logs = new Dictionary<string, float>();
  191. long End_step = 0;
  192. foreach (var step in data_handler.steps())
  193. {
  194. callbacks.on_train_batch_begin(step);
  195. logs = train_step_func(data_handler, iterator);
  196. var end_step = step + data_handler.StepIncrement;
  197. End_step = end_step;
  198. callbacks.on_train_batch_end(end_step, logs);
  199. GC.Collect();
  200. }
  201. if (validation_data != null)
  202. {
  203. if (validation_step > 0 && epoch ==0 || (epoch) % validation_step != 0)
  204. continue;
  205. var val_logs = evaluate(validation_data);
  206. foreach(var log in val_logs)
  207. {
  208. logs["val_" + log.Key] = log.Value;
  209. }
  210. callbacks.on_train_batch_end(End_step, logs);
  211. }
  212. GC.Collect();
  213. callbacks.on_epoch_end(epoch, logs);
  214. if (stop_training)
  215. {
  216. break;
  217. }
  218. }
  219. return callbacks.History;
  220. }
  221. History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, ValidationDataPack validation_data,
  222. Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
  223. {
  224. stop_training = false;
  225. _train_counter.assign(0);
  226. var callbacks = new CallbackList(new CallbackParams
  227. {
  228. Model = this,
  229. Verbose = verbose,
  230. Epochs = epochs,
  231. Steps = data_handler.Inferredsteps
  232. });
  233. if (callbackList != null)
  234. {
  235. foreach (var callback in callbackList)
  236. callbacks.callbacks.add(callback);
  237. }
  238. callbacks.on_train_begin();
  239. foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
  240. {
  241. reset_metrics();
  242. callbacks.on_epoch_begin(epoch);
  243. // data_handler.catch_stop_iteration();
  244. var logs = new Dictionary<string, float>();
  245. long End_step = 0;
  246. foreach (var step in data_handler.steps())
  247. {
  248. callbacks.on_train_batch_begin(step);
  249. logs = train_step_func(data_handler, iterator);
  250. var end_step = step + data_handler.StepIncrement;
  251. End_step = end_step;
  252. callbacks.on_train_batch_end(end_step, logs);
  253. GC.Collect();
  254. }
  255. if (validation_data != null)
  256. {
  257. // Because evaluate calls call_test_batch_end, this interferes with our output on the screen
  258. // so we need to pass a is_val parameter to stop on_test_batch_end
  259. var (val_x, val_y, val_sample_weight) = validation_data;
  260. var val_logs = evaluate(val_x, val_y, sample_weight:val_sample_weight, is_val:true);
  261. foreach (var log in val_logs)
  262. {
  263. logs["val_" + log.Key] = log.Value;
  264. }
  265. // because after evaluate, logs add some new log which we need to print
  266. callbacks.on_train_batch_end(End_step, logs);
  267. }
  268. callbacks.on_epoch_end(epoch, logs);
  269. GC.Collect();
  270. if (stop_training)
  271. {
  272. break;
  273. }
  274. }
  275. return callbacks.History;
  276. }
  277. }
  278. }