From 6c07778243fb0bc8ab6d209e33a87703db10bee1 Mon Sep 17 00:00:00 2001 From: AsakusaRinne Date: Fri, 3 Feb 2023 20:37:25 +0800 Subject: [PATCH] Add two simple sequential test case of pb model save. --- src/TensorFlowNET.Keras/Engine/Model.Save.cs | 2 +- .../SequentialModelTest.cs} | 66 +++++++++---------- 2 files changed, 34 insertions(+), 34 deletions(-) rename test/TensorFlowNET.Keras.UnitTest/{SaveTest.cs => SaveModel/SequentialModelTest.cs} (53%) diff --git a/src/TensorFlowNET.Keras/Engine/Model.Save.cs b/src/TensorFlowNET.Keras/Engine/Model.Save.cs index 85da920e..a1e891f9 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Save.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Save.cs @@ -25,7 +25,7 @@ namespace Tensorflow.Keras.Engine ConcreteFunction? signatures = null, bool save_traces = true) { - if (save_format != "pb") + if (save_format != "tf") { saver.save(this, filepath); } diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs similarity index 53% rename from test/TensorFlowNET.Keras.UnitTest/SaveTest.cs rename to test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs index 90d0a48a..288a92b3 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveTest.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelTest.cs @@ -17,18 +17,13 @@ using Tensorflow.Keras.Metrics; using Tensorflow.Keras.Optimizers; using Tensorflow.Operations; -namespace TensorFlowNET.Keras.UnitTest; - -public static class AutoGraphExtension -{ - -} +namespace TensorFlowNET.Keras.UnitTest.SaveModel; [TestClass] -public class SaveTest +public class SequentialModelTest { [TestMethod] - public void Test() + public void SimpleModelFromAutoCompile() { var inputs = new KerasInterface().Input((28, 28, 1)); var x = new Flatten(new FlattenArgs()).Apply(inputs); @@ -36,10 +31,8 @@ public class SaveTest x = new LayersApi().Dense(units: 10).Apply(x); var outputs = new LayersApi().Softmax(axis: 1).Apply(x); var model = new KerasInterface().Model(inputs, outputs); - - model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[]{"accuracy"}); - var g = ops.get_default_graph(); + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); var data_loader = new MnistModelLoader(); var num_epochs = 1; @@ -49,34 +42,41 @@ public class SaveTest { TrainDir = "mnist", OneHot = false, - ValidationSize = 50000, + ValidationSize = 10000, }).Result; - + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); - - model.save("C:\\Work\\tf.net\\tf_test\\tf.net.model", save_format:"pb"); + + model.save("C:\\Work\\tf.net\\tf_test\\tf.net.simple.compile", save_format: "tf"); } [TestMethod] - public void Temp() + public void SimpleModelFromSequential() { - var graph = new Graph(); - var g = graph.as_default(); - //var input_tensor = array_ops.placeholder(TF_DataType.TF_FLOAT, new int[] { 1 }, "test_string_tensor"); - var input_tensor = tf.placeholder(tf.int32, new int[] { 1 }, "aa"); - var wrapped_func = tf.autograph.to_graph(func); - var res = wrapped_func(input_tensor); - g.Exit(); - } + Model model = KerasApi.keras.Sequential(new List() + { + keras.layers.InputLayer((28, 28, 1)), + keras.layers.Flatten(), + keras.layers.Dense(100, "relu"), + keras.layers.Dense(10), + keras.layers.Softmax(1) + }); - private Tensor func(Tensor tensor) - { - return gen_ops.neg(tensor); - //return array_ops.identity(tensor); - //tf.device("cpu:0"); - //using (ops.control_dependencies(new object[] { res.op })) - //{ - // return array_ops.identity(tensor); - //} + model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" }); + + var data_loader = new MnistModelLoader(); + var num_epochs = 1; + var batch_size = 50; + + var dataset = data_loader.LoadAsync(new ModelLoadSetting + { + TrainDir = "mnist", + OneHot = false, + ValidationSize = 10000, + }).Result; + + model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs); + + model.save("C:\\Work\\tf.net\\tf_test\\tf.net.simple.sequential", save_format: "tf"); } } \ No newline at end of file