You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Model.Compile.cs 3.4 kB

2 years ago
(Line-number gutter: lines 1–108)
  1. using Tensorflow.Keras.ArgsDefinition;
  2. using Tensorflow.Keras.Losses;
  3. using Tensorflow.Keras.Metrics;
  4. using Tensorflow.Keras.Optimizers;
  namespace Tensorflow.Keras.Engine
  {
      public partial class Model
      {
          // Loss container built from `this.loss`; reassigned by every compile() overload.
          private LossesContainer compiled_loss;

          // Metrics container built from the metrics passed to compile(); reassigned by every compile() overload.
          private MetricsContainer compiled_metrics;

          /// <summary>
          /// Compiles the model with the given optimizer and loss and no metrics.
          /// </summary>
          /// <param name="optimizer">Optimizer to use; <c>null</c> falls back to a default <see cref="RMSprop"/>.</param>
          /// <param name="loss">Loss function to use; <c>null</c> falls back to <see cref="MeanSquaredError"/>.</param>
          public void compile(IOptimizer optimizer,
              ILossFunc loss)
          {
              _compile_core(optimizer, loss,
                  new MetricsContainer(System.Array.Empty<string>(), output_names: output_names));
          }

          /// <summary>
          /// Compiles the model with the given optimizer, loss and metric names.
          /// </summary>
          /// <param name="optimizer">Optimizer to use; <c>null</c> falls back to a default <see cref="RMSprop"/>.</param>
          /// <param name="loss">Loss function to use; <c>null</c> falls back to <see cref="MeanSquaredError"/>.</param>
          /// <param name="metrics">Metric names to track; <c>null</c> is treated as "no metrics".</param>
          public void compile(IOptimizer optimizer,
              ILossFunc loss,
              string[] metrics)
          {
              _compile_core(optimizer, loss,
                  new MetricsContainer(metrics ?? System.Array.Empty<string>(), output_names: output_names));
          }

          /// <summary>
          /// Compiles the model from string identifiers for the optimizer and loss.
          /// </summary>
          /// <param name="optimizer">
          /// Optimizer identifier. Only RMSprop is resolved here; every value, including
          /// <c>null</c> and unrecognized names, yields a default <see cref="RMSprop"/>.
          /// </param>
          /// <param name="loss">
          /// Loss identifier: <c>"mae"</c>, <c>"MAE"</c> or <c>"mean_absolute_error"</c> select
          /// <see cref="MeanAbsoluteError"/>; everything else (including <c>"mse"</c>, <c>null</c>
          /// and unrecognized names) selects <see cref="MeanSquaredError"/>.
          /// </param>
          /// <param name="metrics">Metric names to track; <c>null</c> is treated as "no metrics".</param>
          public void compile(string optimizer,
              string loss,
              string[] metrics)
          {
              _compile_core(_get_optimizer(optimizer), _get_loss(loss),
                  new MetricsContainer(metrics ?? System.Array.Empty<string>(), output_names: output_names));
          }

          /// <summary>
          /// Compiles the model with the given optimizer, loss and metric function objects.
          /// </summary>
          /// <param name="optimizer">Optimizer to use; <c>null</c> falls back to a default <see cref="RMSprop"/>.</param>
          /// <param name="loss">Loss function to use; <c>null</c> falls back to <see cref="MeanSquaredError"/>.</param>
          /// <param name="metrics">Metric functions to track; <c>null</c> is treated as "no metrics".</param>
          public void compile(IOptimizer optimizer,
              ILossFunc loss,
              IMetricFunc[] metrics)
          {
              _compile_core(optimizer, loss,
                  new MetricsContainer(metrics ?? System.Array.Empty<IMetricFunc>(), output_names: output_names));
          }

          /// <summary>
          /// Shared body of every <c>compile</c> overload: stores the optimizer and loss (substituting
          /// defaults for <c>null</c>), builds the loss container, installs the metrics container,
          /// configures steps per execution, resets the compile cache and marks the model compiled.
          /// </summary>
          private void _compile_core(IOptimizer optimizer, ILossFunc loss, MetricsContainer metrics_container)
          {
              this.optimizer = optimizer ?? new RMSprop(new RMSpropArgs());
              this.loss = loss ?? new MeanSquaredError();
              compiled_loss = new LossesContainer(this.loss, output_names: output_names);
              compiled_metrics = metrics_container;

              // Steps per execution is fixed at 1 for every compile() overload.
              const int experimental_steps_per_execution = 1;
              _configure_steps_per_execution(experimental_steps_per_execution);

              // Initialize cache attrs.
              _reset_compile_cache();
              _is_compiled = true;
          }

          /// <summary>
          /// Resolves an optimizer identifier. RMSprop is the only optimizer resolved so far, so every
          /// identifier (including <c>null</c> and unrecognized names) returns a default RMSprop.
          /// </summary>
          // NOTE(review): unknown names silently fall back to RMSprop (preserved original behavior);
          // consider throwing once more optimizers are wired up.
          private static IOptimizer _get_optimizer(string identifier)
              => new RMSprop(new RMSpropArgs());

          /// <summary>
          /// Resolves a loss identifier using the Keras aliases for mean squared / mean absolute error.
          /// Unrecognized names and <c>null</c> fall back to <see cref="MeanSquaredError"/>, as before.
          /// </summary>
          private static ILossFunc _get_loss(string identifier) => identifier switch
          {
              "mse" or "MSE" or "mean_squared_error" => new MeanSquaredError(),
              "mae" or "MAE" or "mean_absolute_error" => new MeanAbsoluteError(),
              _ => new MeanSquaredError()
          };
      }
  }