You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

layer_serialization.cs 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. using System.Collections.Generic;
  2. using Newtonsoft.Json;
  3. using Newtonsoft.Json.Linq;
  4. using Tensorflow.Keras.Engine;
  5. using Tensorflow.Keras.Utils;
  6. using Tensorflow.Train;
  7. namespace Tensorflow.Keras.Saving.SavedModel;
  8. public class LayerSavedModelSaver: SavedModelSaver
  9. {
  10. private Layer _obj;
  11. public LayerSavedModelSaver(Layer obj): base(obj)
  12. {
  13. _obj = obj;
  14. }
  15. public override string ObjectIdentifier
  16. {
  17. get => Constants.LAYER_IDENTIFIER;
  18. }
  19. public override IDictionary<string, Trackable> objects_to_serialize(IDictionary<string, object> serialization_cache)
  20. {
  21. return get_serialized_attributes(serialization_cache).ObjectsToSerialize;
  22. }
  23. public override IDictionary<string, Trackable> functions_to_serialize(IDictionary<string, object> serialization_cache)
  24. {
  25. return get_serialized_attributes(serialization_cache).FunctionsToSerialize;
  26. }
  27. /// <summary>
  28. /// Generates or retrieves serialized attributes from cache.
  29. /// </summary>
  30. /// <param name="serialization_cache"></param>
  31. protected SerializedAttributes get_serialized_attributes(IDictionary<string, object> serialization_cache)
  32. {
  33. // TODO: deal with cache.
  34. var serialized_attr = SerializedAttributes.Create(_obj);
  35. // TODO: complete the statement. Currently the `Layer` lacks member `_must_restore_from_config`.
  36. if (KerasSavedModelUtils.should_skip_serialization(_obj))
  37. {
  38. return serialized_attr;
  39. }
  40. var (object_dict, function_dict) = get_serialized_attributes_internal(serialization_cache);
  41. serialized_attr.set_and_validate_objects(object_dict);
  42. serialized_attr.set_and_validate_functions(function_dict);
  43. return serialized_attr;
  44. }
  45. /// <summary>
  46. /// Returns dictionary of serialized attributes.
  47. /// </summary>
  48. /// <param name="serialization_cache"></param>
  49. private (IDictionary<string, Trackable>, IDictionary<string, Trackable>) get_serialized_attributes_internal(IDictionary<string, object> serialization_cache)
  50. {
  51. var objects = KerasSavedModelUtils.wrap_layer_objects(_obj, serialization_cache);
  52. var functions = KerasSavedModelUtils.wrap_layer_functions(_obj, serialization_cache);
  53. functions["_default_save_signature"] = null;
  54. return (objects, functions);
  55. }
  56. public override string TrackingMetadata
  57. {
  58. get
  59. {
  60. JObject metadata = new JObject();
  61. metadata["name"] = _obj.Name;
  62. metadata["trainable"] = _obj.Trainable;
  63. // metadata["expects_training_arg"] = _obj._expects_training_arg;
  64. // metadata["dtype"] = policy.serialize(_obj._dtype_policy)
  65. metadata["batch_input_shape"] = JToken.FromObject(_obj.BatchInputShape);
  66. // metadata["stateful"] = _obj.stateful;
  67. // metadata["must_restore_from_config"] = _obj.must_restore_from_config;
  68. // metadata["preserve_input_structure_in_config"] = _obj.preserve_input_structure_in_config;
  69. metadata["autocast"] = _obj.AutoCast;
  70. metadata.Merge(JObject.FromObject(get_serialized(_obj)), new JsonMergeSettings
  71. {
  72. // Handle conflicts by using values from obj2
  73. MergeArrayHandling = MergeArrayHandling.Merge
  74. });
  75. // skip the check of `input_spec` and `build_input_shape` for the lack of members.
  76. // skip the check of `activity_regularizer` for the type problem.
  77. return metadata.ToString();
  78. }
  79. }
  80. public static LayerConfig get_serialized(Layer obj)
  81. {
  82. return generic_utils.serialize_keras_object(obj);
  83. }
  84. }