You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ModelSaveTest.cs 7.9 kB

1 year ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using Tensorflow.Keras.Engine;
  5. using Tensorflow.Keras.Layers;
  6. using Tensorflow.Keras.Models;
  7. using Tensorflow.Keras.Optimizers;
  8. using Tensorflow.Keras.Saving;
  9. using Tensorflow.Keras.UnitTest.Helpers;
  10. using static Tensorflow.Binding;
  11. using static Tensorflow.KerasApi;
  namespace Tensorflow.Keras.UnitTest.Model
  {
      /// <summary>
      /// Tests for saving and serializing Keras models, following
      /// https://www.tensorflow.org/guide/keras/save_and_serialize
      /// </summary>
      [TestClass]
      public class ModelSaveTest : EagerModeTestBase
      {
          /// <summary>
          /// Rebuilds a functional model from its config and checks that the
          /// rebuilt model has the same number of layers as the original.
          /// </summary>
          [TestMethod]
          public void GetAndFromConfig()
          {
              var model = GetFunctionalModel();

              var config = model.get_config();
              // Debug.Assert is stripped from Release builds and is not reported as an
              // MSTest failure, so check the config type with a real assertion before casting.
              Assert.IsInstanceOfType(config, typeof(FunctionalConfig));
              var new_model = new ModelsApi().from_config((FunctionalConfig)config);

              Assert.AreEqual(model.Layers.Count, new_model.Layers.Count);
          }

          /// <summary>
          /// Builds a minimal functional model: an input of shape 32 followed by a single Dense(1) layer.
          /// </summary>
          private IModel GetFunctionalModel()
          {
              var inputs = keras.Input(shape: 32);
              var dense_layer = keras.layers.Dense(1);
              var outputs = dense_layer.Apply(inputs);
              return keras.Model(inputs, outputs);
          }

          /// <summary>
          /// Trains a small functional MNIST classifier for one epoch and saves it
          /// in TF SavedModel format to ./pb_simple_compile.
          /// </summary>
          [TestMethod]
          public void SimpleModelFromAutoCompile()
          {
              var inputs = tf.keras.layers.Input((28, 28, 1));
              var x = tf.keras.layers.Flatten().Apply(inputs);
              x = tf.keras.layers.Dense(100, activation: "relu").Apply(x);
              x = tf.keras.layers.Dense(units: 10).Apply(x);
              var outputs = tf.keras.layers.Softmax(axis: 1).Apply(x);
              var model = tf.keras.Model(inputs, outputs);
              model.compile(new Adam(0.001f),
                  tf.keras.losses.SparseCategoricalCrossentropy(),
                  new string[] { "accuracy" });

              FitOnMnistSubset(model);

              SaveAndAssertExists(model, "./pb_simple_compile");
          }

          /// <summary>
          /// Same classifier as <see cref="SimpleModelFromAutoCompile"/>, built with the
          /// Sequential API, trained for one epoch and saved to ./pb_simple_sequential.
          /// </summary>
          [TestMethod]
          public void SimpleModelFromSequential()
          {
              var model = keras.Sequential(new List<ILayer>()
              {
                  tf.keras.layers.InputLayer((28, 28, 1)),
                  tf.keras.layers.Flatten(),
                  tf.keras.layers.Dense(100, "relu"),
                  tf.keras.layers.Dense(10),
                  tf.keras.layers.Softmax()
              });
              model.summary();
              model.compile(new Adam(0.001f), tf.keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });

              FitOnMnistSubset(model);

              SaveAndAssertExists(model, "./pb_simple_sequential");
          }

          /// <summary>
          /// Builds an AlexNet-style Sequential model, fits it for one epoch on 16 random
          /// 227x227x3 samples, and saves it to ./alexnet_from_sequential.
          /// </summary>
          [TestMethod]
          public void AlexnetFromSequential()
          {
              var model = keras.Sequential(new List<ILayer>()
              {
                  tf.keras.layers.InputLayer((227, 227, 3)),
                  tf.keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"),
                  tf.keras.layers.BatchNormalization(),
                  tf.keras.layers.MaxPooling2D((3, 3), strides:(2, 2)),

                  tf.keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"),
                  tf.keras.layers.BatchNormalization(),
                  tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

                  tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
                  tf.keras.layers.BatchNormalization(),

                  tf.keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"),
                  tf.keras.layers.BatchNormalization(),

                  tf.keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"),
                  tf.keras.layers.BatchNormalization(),
                  tf.keras.layers.MaxPooling2D((3, 3), (2, 2)),

                  tf.keras.layers.Flatten(),
                  tf.keras.layers.Dense(4096, activation: "relu"),
                  tf.keras.layers.Dropout(0.5f),

                  tf.keras.layers.Dense(4096, activation: "relu"),
                  tf.keras.layers.Dropout(0.5f),

                  tf.keras.layers.Dense(1000, activation: "linear"),
                  tf.keras.layers.Softmax(1)
              });
              // The network already ends in a Softmax layer, so the loss receives probabilities,
              // not logits; from_logits: true would make the loss apply softmax a second time.
              model.compile(tf.keras.optimizers.Adam(), tf.keras.losses.SparseCategoricalCrossentropy(from_logits: false), new string[] { "accuracy" });

              var num_epochs = 1;
              var batch_size = 8;
              var dataset = new RandomDataSet(new Shape(227, 227, 3), 16);
              model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs);

              SaveAndAssertExists(model, "./alexnet_from_sequential");

              // The saved model can be tested with the following Python code:
              #region alexnet_python_code
              //import pathlib
              //import tensorflow as tf
              //
              //if __name__ == '__main__':
              //    model = tf.keras.models.load_model("./alexnet_from_sequential")
              //    model.summary()
              //    num_classes = 5
              //    batch_size = 128
              //    img_height = 227
              //    img_width = 227
              //    epochs = 100
              //    dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
              //    data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
              //    data_dir = pathlib.Path(data_dir)
              //    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
              //        data_dir,
              //        validation_split=0.2,
              //        subset="training",
              //        seed=123,
              //        image_size=(img_height, img_width),
              //        batch_size=batch_size)
              //    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
              //        data_dir,
              //        validation_split=0.2,
              //        subset="validation",
              //        seed=123,
              //        image_size=(img_height, img_width),
              //        batch_size=batch_size)
              //    # The loaded model ends in Softmax, so its outputs are probabilities.
              //    model.compile(optimizer='adam',
              //                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              //                  metrics=['accuracy'])
              //    model.build((None, img_height, img_width, 3))
              //    history = model.fit(
              //        train_ds,
              //        validation_data=val_ds,
              //        epochs=epochs
              //    )
              #endregion
          }

          /// <summary>
          /// Loads the model stored at Assets/simple_model_from_auto_compile and saves it
          /// again to Assets/saved_auto_compile_after_loading.
          /// </summary>
          [TestMethod]
          public void SaveAfterLoad()
          {
              var model = tf.keras.models.load_model(@"Assets/simple_model_from_auto_compile");
              model.summary();
              model.save("Assets/saved_auto_compile_after_loading");
              // TODO(review): a round-trip check (load_model on "Assets/saved_auto_compile_after_loading"
              // followed by summary()) was left commented out; presumably reloading a re-saved model
              // is not yet supported — confirm before re-enabling it.
          }

          /// <summary>
          /// Loads MNIST from the "mnist" directory with integer (non one-hot) labels and fits
          /// <paramref name="model"/> on the resulting training split.
          /// </summary>
          /// <remarks>
          /// ValidationSize = 58000 presumably moves most samples out of the training split to
          /// keep these save tests fast — confirm against MnistModelLoader.
          /// </remarks>
          /// <param name="model">A compiled model accepting 28x28x1 inputs.</param>
          /// <param name="batch_size">Training batch size.</param>
          /// <param name="num_epochs">Number of training epochs.</param>
          private static void FitOnMnistSubset(IModel model, int batch_size = 50, int num_epochs = 1)
          {
              var data_loader = new MnistModelLoader();
              // GetAwaiter().GetResult() rethrows the original exception rather than wrapping it
              // in an AggregateException the way .Result does.
              var dataset = data_loader.LoadAsync(new ModelLoadSetting
              {
                  TrainDir = "mnist",
                  OneHot = false,
                  ValidationSize = 58000,
              }).GetAwaiter().GetResult();
              model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
          }

          /// <summary>
          /// Saves <paramref name="model"/> in TF SavedModel format to <paramref name="path"/>
          /// and asserts that the output directory was created.
          /// </summary>
          private static void SaveAndAssertExists(IModel model, string path)
          {
              model.save(path, save_format: "tf");
              Assert.IsTrue(System.IO.Directory.Exists(path), $"Expected a SavedModel directory at '{path}'.");
          }
      }
  }