You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Functional.GetConfig.cs 4.2 kB

(Line numbers 1–109 from the source viewer's gutter.)
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Tensorflow.Keras.Layers;
  6. using Tensorflow.Keras.Saving;
  7. using Tensorflow.Keras.Utils;
  8. using static Tensorflow.Binding;
  namespace Tensorflow.Keras.Engine
  {
      public partial class Functional
      {
          /// <summary>
          /// Returns the serializable configuration of this functional model.
          /// </summary>
          /// <returns>The <see cref="FunctionalConfig"/> produced by <see cref="get_network_config"/>.</returns>
          public override IKerasConfig get_config() => get_network_config();

          /// <summary>
          /// Builds the config, which consists of the node graph and serialized layers.
          /// </summary>
          /// <remarks>
          /// The config is assembled in four steps:
          /// (1) for every tracked layer, the inbound nodes whose key is in <c>NetworkNodes</c>
          ///     are renumbered consecutively, starting at 1 when <c>_should_skip_first_node</c>
          ///     returns true for that layer and at 0 otherwise;
          /// (2) inside a <c>SharedObjectSavingScope</c>, every tracked layer is serialized along
          ///     with those of its inbound nodes that are in <c>NetworkNodes</c> and are not input nodes;
          /// (3) and (4) the model's input and output coordinates are recorded, skipping any whose
          ///     node key is not in <c>NetworkNodes</c> and translating the node index through the
          ///     renumbering from step (1).
          /// </remarks>
          /// <returns>A new <see cref="FunctionalConfig"/> describing this model.</returns>
          FunctionalConfig get_network_config()
          {
              // Step 1: compacted index for every node key that belongs to the network.
              var renumbered = new Dictionary<string, int>();
              foreach (var tracked in _self_tracked_trackables)
              {
                  int nextIndex = _should_skip_first_node(tracked) ? 1 : 0;
                  foreach (var (inboundIndex, _) in enumerate(tracked.InboundNodes))
                  {
                      var key = _make_node_key(tracked.Name, inboundIndex);
                      if (!NetworkNodes.Contains(key))
                          continue;

                      renumbered[key] = nextIndex;
                      nextIndex++;
                  }
              }

              // Step 2: serialize each layer together with the inbound nodes the network uses.
              var serializedLayers = new List<LayerConfig>();
              using (SharedObjectSavingScope.Enter())
              {
                  foreach (var tracked in _self_tracked_trackables)
                  {
                      var keptInbound = new List<NodeConfig>();
                      foreach (var (inboundIndex, inbound) in enumerate(tracked.InboundNodes))
                      {
                          var key = _make_node_key(tracked.Name, inboundIndex);
                          // Nodes outside the network graph, and input nodes, are not written out.
                          if (NetworkNodes.Contains(key) && !inbound.is_input)
                              keptInbound.Add(inbound.serialize(_make_node_key, renumbered));
                      }

                      var layerConfig = generic_utils.serialize_layer_to_config(tracked);
                      layerConfig.Name = tracked.Name;
                      layerConfig.InboundNodes = keptInbound;
                      serializedLayers.Add(layerConfig);
                  }
              }

              // Step 3: input endpoints that are part of the network graph.
              var inputEndpoints = new List<NodeConfig>();
              for (int i = 0; i < _input_layers.Count; i++)
              {
                  var (endpointLayer, endpointNode, endpointTensor) = _input_coordinates[i];
                  var key = _make_node_key(endpointLayer.Name, endpointNode);
                  if (NetworkNodes.Contains(key))
                  {
                      inputEndpoints.Add(new NodeConfig
                      {
                          Name = endpointLayer.Name,
                          NodeIndex = renumbered[key],
                          TensorIndex = endpointTensor
                      });
                  }
              }

              // Step 4: output endpoints that are part of the network graph.
              var outputEndpoints = new List<NodeConfig>();
              for (int i = 0; i < _output_layers.Count; i++)
              {
                  var (endpointLayer, endpointNode, endpointTensor) = _output_coordinates[i];
                  var key = _make_node_key(endpointLayer.Name, endpointNode);
                  if (NetworkNodes.Contains(key))
                  {
                      outputEndpoints.Add(new NodeConfig
                      {
                          Name = endpointLayer.Name,
                          NodeIndex = renumbered[key],
                          TensorIndex = endpointTensor
                      });
                  }
              }

              return new FunctionalConfig
              {
                  Name = name,
                  Layers = serializedLayers,
                  InputLayers = inputEndpoints,
                  OutputLayers = outputEndpoints
              };
          }

          /// <summary>
          /// Builds the lookup key for inbound node number <paramref name="node_index"/> of the layer
          /// named <paramref name="layer_name"/>.
          /// </summary>
          /// <returns>The string <c>{layer_name}_ib-{node_index}</c>.</returns>
          string _make_node_key(string layer_name, int node_index)
              => layer_name + "_ib-" + node_index;
      }
  }