From 88fe402b7fa39d180cfa6385b4493af82c2e7be2 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Sat, 4 Feb 2023 21:04:00 +0800 Subject: [PATCH] Add alexnet pb save test. --- .../SaveModel/SequentialModelTest.cs | 202 ++++++++++++++++++ 1 file changed, 202 insertions(+) create mode 100644 test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs new file mode 100644 index 00000000..c3145344 --- /dev/null +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs @@ -0,0 +1,202 @@ +using Microsoft.VisualStudio.TestTools.UnitTesting; +using Tensorflow.NumPy; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Tensorflow; +using static Tensorflow.Binding; +using static Tensorflow.KerasApi; +using Tensorflow.Keras; +using Tensorflow.Keras.ArgsDefinition; +using Tensorflow.Keras.Engine; +using Tensorflow.Keras.Layers; +using Tensorflow.Keras.Losses; +using Tensorflow.Keras.Metrics; +using Tensorflow.Keras.Optimizers; +using Tensorflow.Operations; +using System.Diagnostics; + +namespace TensorFlowNET.Keras.UnitTest.SaveModel; + +[TestClass] +public class SequentialModelTest +{ + [TestMethod] + public void SimpleModelFromAutoCompile() + { + var inputs = new KerasInterface().Input((28, 28, 1)); + var x = new Flatten(new FlattenArgs()).Apply(inputs); + x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x); + x = new LayersApi().Dense(units: 10).Apply(x); + var outputs = new LayersApi().Softmax(axis: 1).Apply(x); + var model = new KerasInterface().Model(inputs, outputs); + + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + var data_loader = new MnistModelLoader(); + var num_epochs = 1; + var batch_size = 50; + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 10000, + }).Result; + + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + + model.save("./pb_simple_compile", save_format: "tf"); + } + + [TestMethod] + public void SimpleModelFromSequential() + { + Model model = KerasApi.keras.Sequential(new List() + { + keras.layers.InputLayer((28, 28, 1)), + keras.layers.Flatten(), + keras.layers.Dense(100, "relu"), + keras.layers.Dense(10), + keras.layers.Softmax(1) + }); + + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + var data_loader = new MnistModelLoader(); + var num_epochs = 1; + var batch_size = 50; + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 10000, + }).Result; + + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + + model.save("./pb_simple_sequential", save_format: "tf"); + } + + [TestMethod] + public void AlexModelFromSequential() + { + Model model = KerasApi.keras.Sequential(new List() + { + keras.layers.InputLayer((227, 227, 3)), + keras.layers.Conv2D(96, (11, 11), (4, 4), activation:"relu", padding:"valid"), + keras.layers.BatchNormalization(), + keras.layers.MaxPooling2D((3, 3), strides:(2, 2)), + + keras.layers.Conv2D(256, (5, 5), (1, 1), "same", activation: "relu"), + keras.layers.BatchNormalization(), + keras.layers.MaxPooling2D((3, 3), (2, 2)), + + keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"), + keras.layers.BatchNormalization(), + + keras.layers.Conv2D(384, (3, 3), (1, 1), "same", activation: "relu"), + keras.layers.BatchNormalization(), + + keras.layers.Conv2D(256, (3, 3), (1, 1), "same", activation: "relu"), + keras.layers.BatchNormalization(), + keras.layers.MaxPooling2D((3, 3), (2, 2)), + + keras.layers.Flatten(), + keras.layers.Dense(4096, activation: "relu"), + keras.layers.Dropout(0.5f), + + keras.layers.Dense(4096, activation: "relu"), + keras.layers.Dropout(0.5f), + + keras.layers.Dense(1000, activation: "linear"), + keras.layers.Softmax(1) + }); + + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(from_logits:true), new string[] { "accuracy" }); + + var num_epochs = 1; + var batch_size = 16; + + var dataset = new RandomDataSet(new Shape(227, 227, 3), 16); + + model.fit(dataset.Data, dataset.Labels, batch_size, num_epochs); + + model.save("./pb_elex_sequential", save_format: "tf"); + + // The saved model can be test with the following python code: + #region alexnet_python_code + //import pathlib + //import tensorflow as tf + + //def func(a): + // return -a + + //if __name__ == '__main__': + // model = tf.keras.models.load_model("./pb_elex_sequential") + // model.summary() + + // num_classes = 5 + // batch_size = 128 + // img_height = 227 + // img_width = 227 + // epochs = 100 + + // dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz" + // data_dir = tf.keras.utils.get_file('flower_photos', origin = dataset_url, untar = True) + // data_dir = pathlib.Path(data_dir) + + // train_ds = tf.keras.preprocessing.image_dataset_from_directory( + // data_dir, + // validation_split = 0.2, + // subset = "training", + // seed = 123, + // image_size = (img_height, img_width), + // batch_size = batch_size) + + // val_ds = tf.keras.preprocessing.image_dataset_from_directory( + // data_dir, + // validation_split = 0.2, + // subset = "validation", + // seed = 123, + // image_size = (img_height, img_width), + // batch_size = batch_size) + + + // model.compile(optimizer = 'adam', + // loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True), + // metrics =['accuracy']) + + // model.build((None, img_height, img_width, 3)) + + // history = model.fit( + // train_ds, + // validation_data = val_ds, + // epochs = epochs + // ) + #endregion + } + + public class RandomDataSet : DataSetBase + { + private Shape _shape; + + public RandomDataSet(Shape shape, int count) + { + _shape = shape; + Debug.Assert(_shape.ndim == 3); + long[] dims = new long[4]; + dims[0] = count; + for (int i = 1; i < 4; i++) + { + dims[i] = _shape[i - 1]; + } + Shape s = new Shape(dims); + Data = np.random.normal(0, 2, s); + Labels = np.random.uniform(0, 1, (count, 1)); + } + } +} \ No newline at end of file