You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Save.cs 3.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. using System;
  2. using System.Collections.Generic;
  3. using System.IO;
  4. using System.Linq;
  5. using Google.Protobuf;
  6. using ICSharpCode.SharpZipLib.Zip;
  7. using Tensorflow.Checkpoint;
  8. using Tensorflow.Contexts;
  9. using Tensorflow.Functions;
  10. using Tensorflow.Keras.Engine;
  11. using Tensorflow.Keras.Utils;
  12. using Tensorflow.ModelSaving;
  13. using Tensorflow.Train;
  14. using Tensorflow.Exceptions;
  15. using Tensorflow.IO;
  16. using Tensorflow.Keras.Optimizers;
  17. using ThirdParty.Tensorflow.Python.Keras.Protobuf;
  18. using static Tensorflow.Binding;
  19. namespace Tensorflow.Keras.Saving.SavedModel;
  public partial class KerasSavedModelUtils
  {
      /// <summary>
      /// Saves <paramref name="model"/> in the SavedModel format at <paramref name="filepath"/>.
      /// It then writes the Keras metadata proto to <c>Path.Combine(filepath, Constants.SAVED_METADATA_PATH)</c>.
      /// </summary>
      /// <param name="model">The Keras model to save.</param>
      /// <param name="filepath">
      /// Target location. The metadata file is written inside it via <c>Path.Combine</c>, so it is used as a directory.
      /// </param>
      /// <param name="overwrite">
      /// When false, saving fails if a file or directory already exists at <paramref name="filepath"/>.
      /// </param>
      /// <param name="include_optimizer">
      /// When false, the model's optimizer is detached (and its tracking removed) for the duration of the save.
      /// It is restored afterwards, even if saving fails.
      /// </param>
      /// <param name="signatures">Optional signature function, forwarded to <c>Tensorflow.SavedModelUtils.save_and_return_nodes</c>.</param>
      /// <param name="options">Optional save options, forwarded to <c>Tensorflow.SavedModelUtils.save_and_return_nodes</c>.</param>
      /// <param name="save_traces">
      /// Passed to <c>keras_option_scope</c>. When true, the <c>should_skip_serialization</c> check is also performed.
      /// </param>
      /// <exception cref="IOException">
      /// <paramref name="overwrite"/> is false and <paramref name="filepath"/> already exists.
      /// </exception>
      /// <exception cref="NotImplementedException">
      /// <paramref name="save_traces"/> is true and <c>should_skip_serialization(model)</c> returns true.
      /// </exception>
      public static void Save(Model model, string filepath, bool overwrite, bool include_optimizer, ConcreteFunction? signatures,
          SaveOptions? options, bool save_traces = true)
      {
          // The SavedModel is written as a directory (the metadata file goes inside filepath).
          // File.Exists returns false for directories, so Directory.Exists must be checked as well.
          if (!overwrite && (File.Exists(filepath) || Directory.Exists(filepath)))
          {
              throw new IOException("The file already exists but is not allowed to overwrite it.");
          }

          if (save_traces && should_skip_serialization(model))
          {
              throw new NotImplementedException();
          }

          OptimizerV2? orig_optimizer = null;
          if (!include_optimizer)
          {
              // Detach the optimizer so it is excluded from what gets saved.
              orig_optimizer = model.Optimizer;
              model.Optimizer = null;
              model._delete_tracking("optimizer");
          }

          try
          {
              IList<Trackable> saved_nodes;
              IDictionary<Trackable, IEnumerable<TrackableReference>> node_paths;
              // skip two scopes of python
              using (KerasSavedModelUtils.keras_option_scope(save_traces))
              {
                  (saved_nodes, node_paths) = Tensorflow.SavedModelUtils.save_and_return_nodes(model, filepath, signatures, options);
              }

              var metadata = generate_keras_metadata(saved_nodes, node_paths);
              File.WriteAllBytes(Path.Combine(filepath, Constants.SAVED_METADATA_PATH), metadata.ToByteArray());
          }
          finally
          {
              // Restore the optimizer even when saving throws.
              // Otherwise a failed save would leave the caller's model without its optimizer.
              if (!include_optimizer)
              {
                  model.Optimizer = orig_optimizer!;
              }
          }
      }

      /// <summary>
      /// Builds a <see cref="SavedMetadata"/> proto containing one entry for every <see cref="Layer"/> in <paramref name="saved_nodes"/>.
      /// Nodes that are not layers are skipped.
      /// </summary>
      /// <param name="saved_nodes">The saved nodes. Each node's index in this list is recorded as its <c>NodeId</c>.</param>
      /// <param name="node_paths">
      /// Reference path from the root object to each node. It must contain an entry for every layer in <paramref name="saved_nodes"/>.
      /// </param>
      /// <returns>
      /// The metadata proto. A layer with a null or empty path gets node path <c>"root"</c>.
      /// Every other layer gets <c>"root."</c> followed by its reference names joined with <c>'.'</c>.
      /// </returns>
      /// <exception cref="KeyNotFoundException">
      /// Raised by the dictionary indexer when a layer is missing from <paramref name="node_paths"/>.
      /// </exception>
      public static SavedMetadata generate_keras_metadata(IList<Trackable> saved_nodes,
          IDictionary<Trackable, IEnumerable<TrackableReference>> node_paths)
      {
          var metadata = new SavedMetadata();
          for (int i = 0; i < saved_nodes.Count; i++)
          {
              var node = saved_nodes[i];
              if (node is not Layer layer)
              {
                  continue;
              }

              var path = node_paths[node];
              string node_path = path is null || !path.Any()
                  ? "root"
                  : $"root.{string.Join(".", path.Select(x => x.Name))}";

              // Fully qualified, presumably to avoid resolving to the same-named types
              // in the enclosing Tensorflow namespace.
              ThirdParty.Tensorflow.Python.Keras.Protobuf.SavedObject saved_object = new()
              {
                  NodeId = i,
                  NodePath = node_path,
                  // Fixed version stamp written for every node.
                  Version = new ThirdParty.Tensorflow.Python.Keras.Protobuf.VersionDef()
                  {
                      Producer = 2,
                      MinConsumer = 1,
                  },
                  Identifier = layer.ObjectIdentifier,
                  Metadata = layer.TrackingMetadata
              };
              metadata.Nodes.Add(saved_object);
          }
          return metadata;
      }
  }