You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ModelSaveTest.cs 8.3 kB

1 year ago
1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using Tensorflow.Keras.Engine;
  5. using Tensorflow.Keras.Layers;
  6. using Tensorflow.Keras.Models;
  7. using Tensorflow.Keras.Optimizers;
  8. using Tensorflow.Keras.Saving;
  9. using Tensorflow.Keras.UnitTest.Helpers;
  10. using static Tensorflow.Binding;
  11. using static Tensorflow.KerasApi;
  namespace Tensorflow.Keras.UnitTest.Model
  {
      /// <summary>
      /// Model save / serialization tests, modelled on
      /// https://www.tensorflow.org/guide/keras/save_and_serialize
      /// </summary>
      [TestClass]
      public class ModelSaveTest : EagerModeTestBase
      {
          /// <summary>
          /// Round-trips a functional model through <c>get_config</c> / <c>from_config</c>
          /// and checks that the rebuilt model has the same number of layers.
          /// </summary>
          [TestMethod]
          public void GetAndFromConfig()
          {
              var model = GetFunctionalModel();

              var config = model.get_config();
              // Use the MSTest assertion, not Debug.Assert: Debug.Assert is
              // [Conditional("DEBUG")] and is compiled out of Release builds, after which
              // an `as` cast would silently hand null to from_config.
              Assert.IsInstanceOfType(config, typeof(FunctionalConfig));
              var new_model = new ModelsApi().from_config((FunctionalConfig)config);

              Assert.AreEqual(model.Layers.Count, new_model.Layers.Count);
          }

          /// <summary>
          /// Builds a minimal functional model: <c>keras.Input(shape: 32)</c> feeding a single
          /// <c>Dense(1)</c> layer.
          /// </summary>
          /// <returns>The functional model built from the input and dense output.</returns>
          private IModel GetFunctionalModel()
          {
              // Create a simple model.
              var inputs = keras.Input(shape: 32);
              var dense_layer = keras.layers.Dense(1);
              var outputs = dense_layer.Apply(inputs);
              return keras.Model(inputs, outputs);
          }

          /// <summary>
          /// Builds a functional MNIST classifier (Flatten, Dense(100, relu), Dense(10), Softmax),
          /// trains it for one epoch and saves it to ./pb_simple_compile with save_format "tf".
          /// </summary>
          [TestMethod]
          public void SimpleModelFromAutoCompile()
          {
              var inputs = tf.keras.layers.Input((28, 28, 1));
              var x = tf.keras.layers.Flatten().Apply(inputs);
              x = tf.keras.layers.Dense(100, activation: "relu").Apply(x);
              x = tf.keras.layers.Dense(units: 10).Apply(x);
              var outputs = tf.keras.layers.Softmax(axis: 1).Apply(x);
              var model = tf.keras.Model(inputs, outputs);

              model.compile(new Adam(0.001f),
                  tf.keras.losses.SparseCategoricalCrossentropy(),
                  new string[] { "accuracy" });

              FitOnMnist(model, batch_size: 50, num_epochs: 1);
              model.save("./pb_simple_compile", save_format: "tf");
          }

          /// <summary>
          /// Same architecture as <see cref="SimpleModelFromAutoCompile"/> but built with
          /// <c>keras.Sequential</c>; trains for one epoch and saves to ./pb_simple_sequential.
          /// </summary>
          [TestMethod]
          public void SimpleModelFromSequential()
          {
              var model = keras.Sequential(new List<ILayer>()
              {
                  tf.keras.layers.InputLayer((28, 28, 1)),
                  tf.keras.layers.Flatten(),
                  tf.keras.layers.Dense(100, "relu"),
                  tf.keras.layers.Dense(10),
                  tf.keras.layers.Softmax()
              });
              model.summary();

              model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

              FitOnMnist(model, batch_size: 50, num_epochs: 1);
              model.save("./pb_simple_sequential", save_format: "tf");
          }

          /// <summary>
          /// Loads MNIST with integer (non one-hot) labels and fits <paramref name="model"/>
          /// on its training split. Shared by the simple MNIST save tests.
          /// </summary>
          /// <param name="model">A compiled model whose input accepts 28x28x1 images.</param>
          /// <param name="batch_size">Mini-batch size passed to <c>fit</c>.</param>
          /// <param name="num_epochs">Number of epochs passed to <c>fit</c>.</param>
          private static void FitOnMnist(IModel model, int batch_size, int num_epochs)
          {
              var data_loader = new MnistModelLoader();
              // NOTE(review): ValidationSize = 58000 presumably leaves only a small training
              // split so the test stays fast - confirm against MnistModelLoader.
              // GetAwaiter().GetResult() rethrows the loader's own exception instead of
              // wrapping it in an AggregateException as .Result does.
              var dataset = data_loader.LoadAsync(new ModelLoadSetting
              {
                  TrainDir = "mnist",
                  OneHot = false,
                  ValidationSize = 58000,
              }).GetAwaiter().GetResult();

              model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
          }

          /// <summary>
          /// Builds an AlexNet-style Sequential model for 227x227x3 inputs, fits it for one
          /// epoch on a small random data set and saves it to ./alexnet_from_sequential.
          /// </summary>
          [TestMethod]
          public void AlexnetFromSequential()
          {
              var model = keras.Sequential(new List<ILayer>()
              {
                  tf.keras.layers.InputLayer((227, 227, 3)),
                  tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"),
                  tf.keras.layers.BatchNormalization(),
                  tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)),

                  tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1L2),
                  tf.keras.layers.BatchNormalization(),

                  tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L2),
                  tf.keras.layers.BatchNormalization(),

                  tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: keras.activations.Relu, bias_regularizer:keras.regularizers.L1),
                  tf.keras.layers.BatchNormalization(),
                  tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

                  tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
                  tf.keras.layers.BatchNormalization(),

                  tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
                  tf.keras.layers.BatchNormalization(),

                  tf.keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"),
                  tf.keras.layers.BatchNormalization(),
                  tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

                  tf.keras.layers.Flatten(),
                  tf.keras.layers.Dense(4096, activation: "relu"),
                  tf.keras.layers.Dropout(0.5f),

                  tf.keras.layers.Dense(4096, activation: "relu"),
                  tf.keras.layers.Dropout(0.5f),

                  tf.keras.layers.Dense(1000, activation: "linear"),
                  tf.keras.layers.Softmax(1)
              });

              // The last layer is Softmax, so the model already outputs probabilities;
              // from_logits must therefore stay false (the default), otherwise the loss
              // applies softmax a second time.
              model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

              var num_epochs = 1;
              var batch_size = 8;

              var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);

              model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);

              model.save("./alexnet_from_sequential", save_format: "tf");

              // The saved model can be tested with the following Python code:
              #region alexnet_python_code
              //import pathlib
              //import tensorflow as tf
              //
              //def func(a):
              //    return -a
              //
              //if __name__ == '__main__':
              //    model = tf.keras.models.load_model("./alexnet_from_sequential")
              //    model.summary()
              //
              //    num_classes = 5
              //    batch_size = 128
              //    img_height = 227
              //    img_width = 227
              //    epochs = 100
              //
              //    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
              //    data_dir = tf.keras.utils.get_file('flower_photos', origin = dataset_url, untar = True)
              //    data_dir = pathlib.Path(data_dir)
              //
              //    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
              //        data_dir,
              //        validation_split = 0.2,
              //        subset = "training",
              //        seed = 123,
              //        image_size = (img_height, img_width),
              //        batch_size = batch_size)
              //
              //    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
              //        data_dir,
              //        validation_split = 0.2,
              //        subset = "validation",
              //        seed = 123,
              //        image_size = (img_height, img_width),
              //        batch_size = batch_size)
              //
              //    model.compile(optimizer = 'adam',
              //                  loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = False),
              //                  metrics = ['accuracy'])
              //
              //    model.build((None, img_height, img_width, 3))
              //
              //    history = model.fit(
              //        train_ds,
              //        validation_data = val_ds,
              //        epochs = epochs
              //    )
              #endregion
          }

          /// <summary>
          /// Loads the model stored under Assets/simple_model_from_auto_compile and saves it
          /// again to Assets/saved_auto_compile_after_loading.
          /// </summary>
          [TestMethod]
          public void SaveAfterLoad()
          {
              var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile");
              model.summary();

              model.save("Assets/saved_auto_compile_after_loading");

              // NOTE(review): reloading the re-saved model is disabled here; re-enable once
              // that round trip is confirmed to work.
              //model = tf.keras.models.load_model(@"Assets/saved_auto_compile_after_loading");
              //model.summary();
          }
      }
  }