You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

Functional.cs 15 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using Tensorflow.Keras.ArgsDefinition;
  5. using Tensorflow.Keras.Utils;
  6. using static Tensorflow.Binding;
  7. namespace Tensorflow.Keras.Engine
  8. {
  9. /// <summary>
  10. /// A `Functional` model is a `Model` defined as a directed graph of layers.
  11. /// </summary>
  12. public partial class Functional : Model
  13. {
// Shape of the symbolic inputs captured when the layer graph was built.
TensorShape _build_input_shape;
// NOTE(review): the flags below mirror Keras' Functional attributes; they are
// assigned in _init_graph_network but not read within this file — confirm
// their consumers elsewhere before changing them.
bool _compute_output_and_mask_jointly;
bool _expects_training_arg;
bool _expects_mask_arg;
bool _autocast;
// Layers that produce the model's outputs / receive the model's inputs.
List<ILayer> _output_layers;
List<ILayer> _input_layers;
// Per input/output tensor: (layer, node_index, tensor_index, tensor) coordinates.
List<KerasHistory> _input_coordinates;
List<KerasHistory> _output_coordinates;
// Keys identifying every node in the graph (format: "<layer>_ib-<index>").
public string[] NetworkNodes { get; set; }
// Tensor hash code -> number of downstream consumers; see ComputeTensorUsageCount.
Dictionary<int, int> tensor_usage_count;
public Dictionary<int, int> TensorUsageCount => tensor_usage_count;
  26. public Functional(Tensors inputs, Tensors outputs, string name = null)
  27. : base(new ModelArgs
  28. {
  29. Name = name,
  30. Inputs = inputs,
  31. Outputs = outputs
  32. })
  33. {
  34. _input_layers = new List<ILayer>();
  35. _output_layers = new List<ILayer>();
  36. _input_coordinates = new List<KerasHistory>();
  37. _output_coordinates = new List<KerasHistory>();
  38. tensor_usage_count = new Dictionary<int, int>();
  39. if (this is Sequential)
  40. return;
  41. _init_graph_network(inputs, outputs);
  42. }
/// <summary>
/// Wires this model as a static graph of layers: records the input/output
/// layers and their coordinates, maps every node and layer reachable between
/// <paramref name="inputs"/> and <paramref name="outputs"/>, and precomputes
/// per-tensor usage counts for graph execution.
/// </summary>
/// <param name="inputs">Symbolic input tensors.</param>
/// <param name="outputs">Symbolic output tensors.</param>
protected void _init_graph_network(Tensors inputs, Tensors outputs)
{
    _is_graph_network = true;
    this.inputs = inputs;
    this.outputs = outputs;
    // The topology is fully known up front, so the model counts as built.
    built = true;
    _build_input_shape = inputs.shape;
    _compute_output_and_mask_jointly = true;
    _expects_training_arg = true;
    _expects_mask_arg = true;
    // A graph network does not autocast inputs, as its layers will cast them instead.
    _autocast = false;

    // Make sure every output tensor carries KerasHistory metadata before the
    // coordinate extraction below consults it.
    if (outputs.Any(x => x.KerasHistory == null))
        base_layer_utils.create_keras_history(outputs);

    // Build self._output_layers:
    foreach (var x in outputs)
    {
        var (layer, node_index, tensor_index) = x.KerasHistory;
        _output_layers.append(layer);
        _output_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x));
    }

    // Build self._input_layers:
    foreach (var x in inputs)
    {
        var (layer, node_index, tensor_index) = x.KerasHistory;
        _input_layers.append(layer);
        _input_coordinates.append(new KerasHistory(layer, node_index, tensor_index, x));
    }

    // Keep track of the network's nodes and layers.
    var (nodes, nodes_by_depth, layers, _) = MapGraphNetwork(inputs, outputs);
    NetworkNodes = nodes;
    NodesByDepth = nodes_by_depth;
    _layers = layers;

    // Build self.input_names and self.output_names.
    _set_output_names();
    ComputeTensorUsageCount();
}
  80. /// <summary>
  81. /// Assigns unique names to the Network's outputs.
  82. /// </summary>
  83. void _set_output_names()
  84. {
  85. var uniquified = new List<string>();
  86. var output_names = new List<string>();
  87. var prefix_count = new Dictionary<string, int>();
  88. foreach (var layer in _output_layers)
  89. {
  90. var proposal = layer.Name;
  91. while (output_names.Contains(proposal))
  92. {
  93. var existing_count = prefix_count.Get(layer.Name, 1);
  94. proposal = $"{layer.Name}_{existing_count}";
  95. prefix_count[layer.Name] = existing_count + 1;
  96. }
  97. output_names.add(proposal);
  98. uniquified.append(proposal);
  99. }
  100. this.output_names = uniquified.ToArray();
  101. }
/// <summary>
/// Counts, for every tensor in the graph, how many node inputs (plus model
/// outputs) consume it. run_internal_graph uses these counts to size the
/// per-tensor value queues.
/// </summary>
void ComputeTensorUsageCount()
{
    // NOTE(review): tensors are keyed by GetHashCode(); this assumes hash
    // codes are unique per tensor instance — a collision would conflate
    // usage counts. Confirm Tensor's GetHashCode contract.
    var available_tensors = inputs.Select(x => x.GetHashCode()).ToList();
    // Depths sorted descending; Skip(1) drops the highest depth (the input
    // nodes), whose tensors are seeded above.
    var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().Skip(1).ToArray();
    foreach (var depth in depth_keys)
    {
        foreach (var node in NodesByDepth[depth])
        {
            var input_tensors = node.KerasInputs.Select(x => x.GetHashCode()).ToArray();
            // Only count a node once all of its inputs are reachable.
            if (input_tensors.issubset(available_tensors))
            {
                foreach (var tensor in node.KerasInputs)
                {
                    if (!tensor_usage_count.ContainsKey(tensor.GetHashCode()))
                        tensor_usage_count[tensor.GetHashCode()] = 0;
                    tensor_usage_count[tensor.GetHashCode()] += 1;
                }
                // The node's outputs become available to shallower consumers.
                foreach (var output_tensor in node.Outputs)
                    available_tensors.Add(output_tensor.GetHashCode());
            }
        }
    }
    // Each model output is dequeued once more when results are collected.
    foreach (var tensor in outputs)
    {
        if (!tensor_usage_count.ContainsKey(tensor.GetHashCode()))
            tensor_usage_count[tensor.GetHashCode()] = 0;
        tensor_usage_count[tensor.GetHashCode()] += 1;
    }
}
  131. /// <summary>
  132. /// Validates a network's topology and gather its layers and nodes.
  133. /// </summary>
  134. /// <param name="inputs"></param>
  135. /// <param name="outputs"></param>
  136. (string[], Dictionary<int, List<INode>>, List<ILayer>, Dictionary<int, List<ILayer>>) MapGraphNetwork(Tensors inputs, Tensors outputs)
  137. {
  138. var (nodes_in_decreasing_depth, layer_indices) = BuildMap(outputs);
  139. var network_nodes = nodes_in_decreasing_depth
  140. .Select(node => MakeNodeKey(node.Layer.Name, node.Layer.InboundNodes.IndexOf(node)))
  141. .ToArray();
  142. var nodes_depths = new Dictionary<INode, int>();
  143. var layers_depths = new Dictionary<ILayer, int>();
  144. nodes_in_decreasing_depth.Reverse();
  145. foreach (var node in nodes_in_decreasing_depth)
  146. {
  147. // If the depth is not set, the node has no outbound nodes (depth 0).
  148. int depth = nodes_depths.SetDefault(node, 0);
  149. // Update the depth of the corresponding layer
  150. int previous_depth = layers_depths.Get(node.Layer, 0);
  151. // If we've seen this layer before at a higher depth,
  152. // we should use that depth instead of the node depth.
  153. // This is necessary for shared layers that have inputs at different
  154. // depth levels in the graph.
  155. depth = Math.Max(depth, previous_depth);
  156. layers_depths[node.Layer] = depth;
  157. nodes_depths[node] = depth;
  158. // Update the depth of inbound nodes.
  159. // The "depth" of a node is the max of the depths
  160. // of all nodes it is connected to + 1.
  161. foreach (var node_dep in node.ParentNodes)
  162. {
  163. previous_depth = nodes_depths.Get(node_dep, 0);
  164. nodes_depths[node_dep] = Math.Max(depth + 1, previous_depth);
  165. }
  166. }
  167. // Handle inputs that are not connected to outputs.
  168. // We do not error out here because the inputs may be used to compute losses
  169. // and metrics.
  170. foreach (var input_t in inputs)
  171. {
  172. var (input_layer, _, _) = input_t.KerasHistory;
  173. if (!layers_depths.ContainsKey(input_layer))
  174. {
  175. layers_depths[input_layer] = 0;
  176. layer_indices[input_layer] = -1;
  177. nodes_depths[input_layer.InboundNodes[0]] = 0;
  178. network_nodes.add(MakeNodeKey(input_layer.Name, 0));
  179. }
  180. }
  181. // Build a dict {depth: list of nodes with this depth}
  182. var nodes_by_depth = new Dictionary<int, List<INode>>();
  183. foreach (var (node, depth) in enumerate(nodes_depths))
  184. {
  185. if (!nodes_by_depth.ContainsKey(depth))
  186. nodes_by_depth[depth] = new List<INode>();
  187. nodes_by_depth[depth].append(node);
  188. }
  189. var layers_by_depth = new Dictionary<int, List<ILayer>>();
  190. foreach (var (layer, depth) in enumerate(layers_depths))
  191. {
  192. if (!layers_by_depth.ContainsKey(depth))
  193. layers_by_depth[depth] = new List<ILayer>();
  194. layers_by_depth[depth].append(layer);
  195. }
  196. // Get sorted list of layer depths.
  197. var depth_keys = layers_by_depth.Keys.OrderBy(x => x).Reverse();
  198. // Set self.layers ordered by depth.
  199. var layers = new List<ILayer>();
  200. foreach (var depth in depth_keys)
  201. {
  202. var layers_for_depth = layers_by_depth[depth];
  203. // Network.layers needs to have a deterministic order:
  204. // here we order them by traversal order.
  205. layers_for_depth = layers_for_depth.OrderBy(x => layer_indices[x]).ToList();
  206. layers.AddRange(layers_for_depth);
  207. }
  208. // Get sorted list of node depths.
  209. depth_keys = nodes_by_depth.Keys.OrderBy(x => x).Reverse();
  210. return (network_nodes, nodes_by_depth, layers, layers_by_depth);
  211. }
  212. string MakeNodeKey(string layer_name, int node_index)
  213. => $"{layer_name}_ib-{node_index}";
  214. /// <summary>
  215. /// This method topologically sorts nodes in order from inputs to outputs.
  216. /// </summary>
  217. /// <param name="outputs"></param>
  218. (List<INode>, Dictionary<ILayer, int>) BuildMap(Tensors outputs)
  219. {
  220. var finished_nodes = new List<INode>();
  221. var nodes_in_progress = new List<INode>();
  222. var nodes_in_decreasing_depth = new List<INode>();
  223. var layer_indices = new Dictionary<ILayer, int>();
  224. foreach (var output in outputs)
  225. BuildMapHelper(output,
  226. finished_nodes,
  227. nodes_in_progress,
  228. nodes_in_decreasing_depth,
  229. layer_indices);
  230. return (nodes_in_decreasing_depth, layer_indices);
  231. }
/// <summary>
/// Recursive depth-first step of <see cref="BuildMap"/>: visits the node that
/// produced <paramref name="tensor"/>, recurses into that node's inputs, and
/// appends the node post-order to <paramref name="nodes_in_decreasing_depth"/>.
/// </summary>
/// <param name="tensor">Tensor whose producing node is visited.</param>
/// <param name="finished_nodes">Nodes already fully processed; memoizes shared subgraphs.</param>
/// <param name="nodes_in_progress">Nodes on the current DFS path; used to detect cycles.</param>
/// <param name="nodes_in_decreasing_depth">Accumulator for the topological order.</param>
/// <param name="layer_indices">First-visit index per layer, used later for deterministic layer ordering.</param>
void BuildMapHelper(Tensor tensor,
    List<INode> finished_nodes,
    List<INode> nodes_in_progress,
    List<INode> nodes_in_decreasing_depth,
    Dictionary<ILayer, int> layer_indices)
{
    var (layer, node_index, _) = tensor.KerasHistory;
    var node = layer.InboundNodes[node_index] as Node;

    // Don't repeat work for shared subgraphs
    if (finished_nodes.Contains(node))
        return;

    // Prevent cycles.
    if (nodes_in_progress.Contains(node))
        throw new ValueError($"The tensor {tensor.name} at layer {layer.Name} is part of a cycle.");

    // Store the traversal order for layer sorting.
    if (!layer_indices.ContainsKey(layer))
        layer_indices[layer] = layer_indices.Count;

    // Propagate to all previous tensors connected to this node.
    nodes_in_progress.Add(node);
    if (!node.is_input)
    {
        foreach (var k_tensor in node.KerasInputs)
        {
            BuildMapHelper(k_tensor,
                finished_nodes,
                nodes_in_progress,
                nodes_in_decreasing_depth,
                layer_indices);
        }
    }

    finished_nodes.Add(node);
    nodes_in_progress.Remove(node);
    // Post-order append: nodes closest to the outputs come first, hence
    // "decreasing depth".
    nodes_in_decreasing_depth.append(node);
}
  266. protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_training = false)
  267. {
  268. return run_internal_graph(inputs, is_training);
  269. }
  270. Tensors run_internal_graph(Tensors inputs, bool training = false, Tensors mask = null)
  271. {
  272. if (mask == null)
  273. {
  274. Tensor[] masks = new Tensor[inputs.Count()];
  275. foreach (var (i, input_t) in enumerate(inputs))
  276. input_t.KerasMask = masks[i];
  277. }
  278. var tensor_dict = new Dictionary<int, Queue<Tensor>>();
  279. foreach (var (x, y) in zip(this.inputs, inputs))
  280. {
  281. var y1 = conform_to_reference_input(y, x);
  282. var x_id = x.GetHashCode();
  283. tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y1));
  284. }
  285. var depth_keys = NodesByDepth.Keys.OrderBy(x => x).Reverse().ToArray();
  286. foreach (var depth in depth_keys)
  287. {
  288. var nodes = NodesByDepth[depth];
  289. foreach (Node node in nodes)
  290. {
  291. // Input tensors already exist.
  292. if (node.is_input)
  293. continue;
  294. var layer_inputs = node.MapArguments(tensor_dict);
  295. tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name}");
  296. var outputs = node.Layer.Apply(layer_inputs, is_training: training);
  297. foreach (var output in outputs.Where(x => x != null))
  298. tf.Logger.Debug($"Depth {depth}: {node.Layer}: {node.Layer.Name} {output.TensorShape}");
  299. // Update tensor_dict for next input
  300. foreach (var (x_id, y) in zip(node.FlatOutputIds, outputs))
  301. tensor_dict[x_id] = new Queue<Tensor>(Enumerable.Range(0, tensor_usage_count[x_id]).Select(x => y));
  302. }
  303. }
  304. var output_tensors = new List<Tensor>();
  305. foreach (var x in outputs)
  306. {
  307. var x_id = x.GetHashCode();
  308. output_tensors.append(tensor_dict[x_id].Dequeue());
  309. }
  310. return output_tensors;
  311. }
  312. Tensor conform_to_reference_input(Tensor tensor, Tensor ref_input)
  313. {
  314. return tensor;
  315. }
  316. }
  317. }