You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

KerasObjectLoader.cs 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. using Newtonsoft.Json;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text.RegularExpressions;
  6. using Tensorflow.Keras.ArgsDefinition;
  7. using Tensorflow.Keras.Engine;
  8. using ThirdParty.Tensorflow.Python.Keras.Protobuf;
  9. using static Tensorflow.Binding;
  10. namespace Tensorflow.Keras.Saving
  11. {
  /// <summary>
  /// Revives Keras layers and models from a SavedModel's Keras metadata
  /// (<see cref="SavedMetadata"/>) and its object graph (<see cref="SavedObjectGraph"/>).
  /// This is a partial port: only Sequential models are revived so far; generic layers,
  /// Functional models and metrics are not supported yet.
  /// </summary>
  public class KerasObjectLoader
  {
      // Sequential/Functional models track their layers under local names "layer-N".
      // Anchored at the start so names such as "sublayer-2" are not taken for layer
      // slots; [0-9] keeps the captured digits parseable by int.TryParse.
      static readonly Regex _layer_name_pattern = new Regex(@"^layer-([0-9]+)");

      // The metadata JSON stores dtypes as strings, but KerasMetaData is deserialized
      // with the numeric value (1 — presumably TF_DataType.TF_FLOAT; confirm). Only
      // float32 is mapped for now. Tolerates any whitespace around the ':'.
      static readonly Regex _float32_dtype_pattern = new Regex("\"dtype\"\\s*:\\s*\"float32\"");

      SavedMetadata _metadata;
      SavedObjectGraph _proto;
      // Node id -> node path, taken from the Keras metadata.
      Dictionary<int, string> _node_paths = new Dictionary<int, string>();
      // Node id -> raw Keras metadata JSON, taken from the Keras metadata.
      Dictionary<int, string> _node_metadata = new Dictionary<int, string>();
      // Graph-network node id -> (revived model, node ids of its layers in order).
      Dictionary<int, (Model, int[])> model_layer_dependencies = new Dictionary<int, (Model, int[])>();
      // Node ids already processed by _add_children_recreated_from_config.
      HashSet<int> _traversed_nodes_from_config = new HashSet<int>();

      public KerasObjectLoader(SavedMetadata metadata, SavedObjectGraph object_graph_def)
      {
          _metadata = metadata;
          _proto = object_graph_def;
          foreach (var node in _metadata.Nodes)
          {
              _node_paths[node.NodeId] = node.NodePath;
              _node_metadata[node.NodeId] = node.Metadata;
          }
      }

      /// <summary>
      /// Load all layer nodes from the metadata.
      /// </summary>
      /// <param name="compile">Currently unused; kept for parity with the Keras loader's
      /// <c>compile</c> flag.</param>
      public void load_layers(bool compile = true)
      {
          foreach (var node_metadata in _metadata.Nodes)
          {
              // TODO: revive metrics after all layers once metric revival is ported;
              // until then metric nodes are skipped.
              if (node_metadata.Identifier == "_tf_keras_metric")
                  continue;
              _load_layer(node_metadata.NodeId, node_metadata.Identifier, node_metadata.Metadata);
          }
      }

      /// <summary>
      /// Parses a node's metadata JSON and tries to revive it from its config.
      /// </summary>
      void _load_layer(int node_id, string identifier, string metadata_json)
      {
          var metadata = _parse_metadata(metadata_json);
          _revive_from_config(identifier, metadata, node_id);
      }

      /// <summary>
      /// Deserializes Keras metadata JSON into <see cref="KerasMetaData"/> after mapping the
      /// "float32" dtype string to its numeric value.
      /// </summary>
      /// <exception cref="ArgumentNullException"><paramref name="metadata_json"/> is null.</exception>
      static KerasMetaData _parse_metadata(string metadata_json)
      {
          var normalized = _float32_dtype_pattern.Replace(metadata_json, "\"dtype\": 1");
          return JsonConvert.DeserializeObject<KerasMetaData>(normalized);
      }

      /// <summary>
      /// Revives a layer/model from config. Does nothing if no object could be revived.
      /// </summary>
      /// <param name="identifier">Keras object identifier of the node.</param>
      /// <param name="metadata">Deserialized metadata of the node.</param>
      /// <param name="node_id">Id of the node in the object graph.</param>
      void _revive_from_config(string identifier, KerasMetaData metadata, int node_id)
      {
          var obj = _revive_graph_network(identifier, metadata, node_id)
              ?? _revive_layer_or_model_from_config(metadata, node_id);
          // Nothing was revived (e.g. an unsupported layer type): there is no object to
          // attach children to, so stop here instead of dereferencing null below.
          if (obj is null)
              return;
          _add_children_recreated_from_config(obj, _proto.Nodes[node_id], node_id);
      }

      /// <summary>
      /// Revives a Sequential/Functional model as an empty model and records its layer
      /// node ids so the model can be reconstructed once its layers exist.
      /// </summary>
      /// <returns>The revived model, or null if the node is not a graph network.</returns>
      /// <exception cref="NotImplementedException">The node is a graph network other than
      /// a Sequential model.</exception>
      Model _revive_graph_network(string identifier, KerasMetaData metadata, int node_id)
      {
          var config = metadata.Config;
          var class_name = metadata.ClassName;

          // Decide eligibility before constructing anything, so no model is built and then
          // thrown away, and a Sequential is revived even when is_graph_network is false.
          bool is_functional_or_sequential = metadata.IsGraphNetwork
              || class_name == "Sequential"
              || class_name == "Functional";
          if (!is_functional_or_sequential)
              return null;

          Model model;
          if (class_name == "Sequential")
          {
              model = new Sequential(new SequentialArgs
              {
                  Name = config.Name
              });
          }
          else if (identifier == "_tf_keras_sequential")
          {
              // Custom Sequential subclass: use the class name as the model name.
              model = new Sequential(new SequentialArgs
              {
                  Name = class_name
              });
          }
          else
          {
              throw new NotImplementedException(
                  $"Reviving graph network '{class_name}' is not supported yet; only Sequential models can be loaded.");
          }

          // Record this model and its layers. This will later be used to reconstruct
          // the model.
          var layers = _get_child_layer_node_ids(node_id);
          model_layer_dependencies[node_id] = (model, layers);
          return model;
      }

      /// <summary>
      /// Revives a generic layer or subclassed model from its config.
      /// Not implemented yet: always returns null, so such nodes are skipped.
      /// </summary>
      Model _revive_layer_or_model_from_config(KerasMetaData metadata, int node_id)
      {
          // TODO: port; needs metadata.ClassName, metadata.Config, metadata.SharedObjectId
          // and metadata.MustRestoreFromConfig.
          return null;
      }

      /// <summary>
      /// Returns the node ids of each layer in a Sequential/Functional model, ordered by
      /// layer index. Collection stops at the first missing index.
      /// </summary>
      /// <param name="node_id">Id of the model node in the object graph.</param>
      int[] _get_child_layer_node_ids(int node_id)
      {
          var child_layers = new Dictionary<int, int>();
          foreach (var child in _proto.Nodes[node_id].Children)
          {
              var m = _layer_name_pattern.Match(child.LocalName);
              // Skip non-layer children and indices too large for an int.
              if (!m.Success || !int.TryParse(m.Groups[1].Value, out var layer_n))
                  continue;
              child_layers[layer_n] = child.NodeId;
          }

          // Take layer-0, layer-1, ... until the first gap.
          var ordered = new List<int>();
          while (child_layers.TryGetValue(ordered.Count, out var child_id))
              ordered.Add(child_id);
          return ordered.ToArray();
      }

      /// <summary>
      /// Records an object recreated from config, building it if it is not built yet.
      /// Each node is processed at most once. Children are not traversed yet.
      /// </summary>
      /// <param name="obj">The revived object; ignored if null.</param>
      /// <param name="proto">The node's entry in the object graph.</param>
      /// <param name="node_id">Id of the node in the object graph.</param>
      void _add_children_recreated_from_config(Model obj, SavedObject proto, int node_id)
      {
          // HashSet.Add returns false when the node was already visited.
          if (obj is null || !_traversed_nodes_from_config.Add(node_id))
              return;

          if (!obj.Built)
          {
              // Prefer the node's Keras metadata (the same source _load_layer parses);
              // fall back to the copy embedded in the object graph.
              var metadata_json = _node_metadata.TryGetValue(node_id, out var json) && !string.IsNullOrEmpty(json)
                  ? json
                  : proto.UserObject.Metadata;
              var metadata = _parse_metadata(metadata_json);
              _try_build_layer(obj, node_id, metadata.BuildInputShape);
          }
      }

      /// <summary>
      /// Attempts to build the layer. Only the already-built case is handled so far.
      /// </summary>
      /// <returns><c>true</c> if <paramref name="obj"/> is built; otherwise <c>false</c>.</returns>
      bool _try_build_layer(Model obj, int node_id, TensorShape build_input_shape)
      {
          // TODO: build obj from build_input_shape (or inferred inputs) when not built.
          return obj.Built;
      }
  }
  145. }