You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SaveTest.cs 2.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Linq;
  6. using System.Text;
  7. using System.Threading.Tasks;
  8. using Tensorflow;
  9. using static Tensorflow.Binding;
  10. using static Tensorflow.KerasApi;
  11. using Tensorflow.Keras;
  12. using Tensorflow.Keras.ArgsDefinition;
  13. using Tensorflow.Keras.Engine;
  14. using Tensorflow.Keras.Layers;
  15. using Tensorflow.Keras.Losses;
  16. using Tensorflow.Keras.Metrics;
  17. using Tensorflow.Keras.Optimizers;
  18. using Tensorflow.Operations;
  19. namespace TensorFlowNET.Keras.UnitTest;
  /// <summary>
  /// Static class that currently declares no members.
  /// </summary>
  public static class AutoGraphExtension { }
  /// <summary>
  /// Tests that build a small Keras model, train it on MNIST and save it,
  /// and that wrap a tensor function with <c>tf.autograph.to_graph</c>.
  /// </summary>
  [TestClass]
  public class SaveTest
  {
      /// <summary>
      /// Builds a Flatten -> Dense(100, relu) -> Dense(10) -> Softmax model over a
      /// (28, 28, 1) input, fits it for one epoch on the MNIST training split and
      /// saves it with <c>save_format: "pb"</c>.
      /// </summary>
      [TestMethod]
      public void Test()
      {
          const int numEpochs = 1;
          const int batchSize = 50;

          var inputs = new KerasInterface().Input((28, 28, 1));
          var x = new Flatten(new FlattenArgs()).Apply(inputs);
          x = new Dense(new DenseArgs() { Units = 100, Activation = tf.nn.relu }).Apply(x);
          x = new LayersApi().Dense(units: 10).Apply(x);
          var outputs = new LayersApi().Softmax(axis: 1).Apply(x);

          var model = new KerasInterface().Model(inputs, outputs);
          model.compile(new Adam(0.001f), new LossesApi().SparseCategoricalCrossentropy(), new string[] { "accuracy" });

          // NOTE(review): the returned graph was never used; the call is kept in case
          // it has side effects on the default graph - confirm and remove if not.
          _ = ops.get_default_graph();

          // GetAwaiter().GetResult() surfaces the loader's original exception instead
          // of the AggregateException that .Result would wrap it in.
          var dataset = new MnistModelLoader().LoadAsync(new ModelLoadSetting
          {
              TrainDir = "mnist",
              OneHot = false,
              // NOTE(review): presumably a large validation split keeps the training
              // set small so the test runs quickly - confirm.
              ValidationSize = 50000,
          }).GetAwaiter().GetResult();

          model.fit(dataset.Train.Data, dataset.Train.Labels, batchSize, numEpochs);

          // Save under the system temp directory rather than a developer-specific
          // absolute path so the test can run on any machine.
          var savePath = System.IO.Path.Combine(System.IO.Path.GetTempPath(), "tf.net.model");
          model.save(savePath, save_format: "pb");
      }

      /// <summary>
      /// Wraps <see cref="func"/> with <c>tf.autograph.to_graph</c> inside a fresh
      /// graph and invokes the wrapper on an int32 placeholder.
      /// </summary>
      [TestMethod]
      public void Temp()
      {
          var graph = new Graph();
          var g = graph.as_default();
          try
          {
              var inputTensor = tf.placeholder(tf.int32, new int[] { 1 }, "aa");
              var wrappedFunc = tf.autograph.to_graph(func);
              var res = wrappedFunc(inputTensor);

              Assert.IsNotNull(res);
          }
          finally
          {
              // Always leave the graph context, even when the wrapped call throws,
              // so a failure here cannot leak graph state into other tests.
              g.Exit();
          }
      }

      /// <summary>
      /// Function handed to <c>tf.autograph.to_graph</c>; forwards its argument to
      /// <c>gen_ops.neg</c>.
      /// </summary>
      /// <param name="tensor">Tensor passed to <c>gen_ops.neg</c>.</param>
      /// <returns>The tensor returned by <c>gen_ops.neg</c>.</returns>
      private Tensor func(Tensor tensor)
      {
          return gen_ops.neg(tensor);
      }
  }