You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SequentialModelLoad.cs 1.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.Linq;
  6. using System.Text;
  7. using System.Threading.Tasks;
  8. using Tensorflow.Keras.Engine;
  9. using Tensorflow.Keras.Saving.SavedModel;
  10. using Tensorflow.Keras.Losses;
  11. using Tensorflow.Keras.Metrics;
  12. using Tensorflow;
  13. using Tensorflow.Keras.Optimizers;
  14. using static Tensorflow.KerasApi;
  namespace TensorFlowNET.Keras.UnitTest.SaveModel;

  /// <summary>
  /// Tests that reload a previously saved Keras model, recompile it and continue training it.
  /// </summary>
  [TestClass]
  public class SequentialModelLoad
  {
      /// <summary>
      /// Environment variable that, when set to a non-blank value, overrides
      /// <see cref="DefaultSimpleSequentialModelPath"/>.
      /// </summary>
      private const string SimpleSequentialModelPathVariable = "TF_NET_SIMPLE_SEQUENTIAL_MODEL_PATH";

      /// <summary>
      /// Machine-specific location the model was originally loaded from. Kept only as a
      /// fallback so existing local setups keep working unchanged.
      /// </summary>
      private const string DefaultSimpleSequentialModelPath = @"D:\development\tf.net\tf_test\tf.net.simple.sequential";

      /// <summary>
      /// Loads the saved model, recompiles it with Adam and sparse categorical cross-entropy
      /// (tracking "accuracy"), and fits it for one epoch on data from <see cref="MnistModelLoader"/>.
      /// </summary>
      /// <remarks>
      /// The saved model is an external artifact; presumably it is produced by
      /// <c>SequentialModelSave.SimpleModelFromSequential</c> (TODO confirm). When it cannot be
      /// found, the test is reported as inconclusive instead of failing inside <c>load_model</c>.
      /// </remarks>
      [TestMethod]
      public void SimpleModelFromSequential()
      {
          const float learningRate = 0.0001f;
          const int numEpochs = 1;
          const int batchSize = 8;

          var modelPath = ResolveSimpleSequentialModelPath();
          // The artifact may be a directory or a single file depending on save format; accept either.
          if (!System.IO.Directory.Exists(modelPath) && !System.IO.File.Exists(modelPath))
          {
              Assert.Inconclusive($"Saved model not found at '{modelPath}'. Set {SimpleSequentialModelPathVariable} to its location.");
          }

          var model = keras.models.load_model(modelPath);
          model.summary();
          model.compile(new Adam(learningRate), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

          var dataLoader = new MnistModelLoader();
          // GetAwaiter().GetResult() instead of .Result so a loader failure surfaces as the
          // original exception rather than an AggregateException.
          // NOTE(review): the large ValidationSize presumably keeps the training split small so
          // the test runs quickly — confirm.
          var dataset = dataLoader.LoadAsync(new ModelLoadSetting
          {
              TrainDir = "mnist",
              OneHot = false,
              ValidationSize = 50000,
          }).GetAwaiter().GetResult();

          model.fit(dataset.Train.Data, dataset.Train.Labels, batchSize, numEpochs);
          model.summary();
      }

      /// <summary>
      /// Returns the model path from <see cref="SimpleSequentialModelPathVariable"/> when it is set
      /// to a non-blank value; otherwise <see cref="DefaultSimpleSequentialModelPath"/>.
      /// </summary>
      private static string ResolveSimpleSequentialModelPath()
      {
          var overridePath = Environment.GetEnvironmentVariable(SimpleSequentialModelPathVariable);
          return string.IsNullOrWhiteSpace(overridePath) ? DefaultSimpleSequentialModelPath : overridePath;
      }
  }