You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

TrackableView.cs 2.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. using System;
  2. using Tensorflow.Train;
  3. using System.Collections.Generic;
  4. using System.IO;
  5. using Tensorflow.Keras.Saving.SavedModel;
  6. namespace Tensorflow.Checkpoint;
  /// <summary>
  /// A view over the objects reachable from a root <see cref="Trackable"/>.
  /// The root is stored as a weak reference and resolved on demand through <see cref="Root"/>.
  /// </summary>
  public class TrackableView
  {
      /// <summary>Weak reference to the root object of this view.</summary>
      protected WeakReference<Trackable> _root_ref;

      /// <summary>
      /// Creates a view rooted at <paramref name="obj"/>, wrapping it in a new weak reference.
      /// </summary>
      /// <param name="obj">The root object.</param>
      public TrackableView(Trackable obj)
          : this(new WeakReference<Trackable>(obj))
      {
      }

      /// <summary>
      /// Creates a view that uses the given weak reference as its root reference.
      /// </summary>
      /// <param name="obj">Weak reference to the root object.</param>
      public TrackableView(WeakReference<Trackable> obj) => _root_ref = obj;

      /// <summary>
      /// Returns the children of <paramref name="obj"/>, keyed by name, as a freshly built dictionary.
      /// </summary>
      /// <param name="obj">The object whose children are requested.</param>
      /// <param name="save_type">Forwarded unchanged to <c>obj._trackable_children</c>.</param>
      /// <param name="cache">Forwarded unchanged to <c>obj._trackable_children</c>.</param>
      /// <returns>A new dictionary holding every entry reported by <c>obj._trackable_children</c>.</returns>
      public virtual IDictionary<string, Trackable> children(Trackable obj, SaveType save_type = SaveType.CHECKPOINT, IDictionary<string, IDictionary<Trackable, ISerializedAttributes>>? cache = null)
      {
          obj._maybe_initialize_trackable();

          // Note: in python the return type of `Trackable._trackable_children` is not fixed,
          // so the python code runs an extra `convert_to_trackable` step; here the entries are copied as-is.
          var result = new Dictionary<string, Trackable>();
          foreach (var entry in obj._trackable_children(save_type, cache))
          {
              result[entry.Key] = entry.Value;
          }

          return result;
      }

      /// <summary>
      /// The root object of this view.
      /// </summary>
      /// <exception cref="InvalidDataException">
      /// Thrown when the weak reference no longer yields a target.
      /// </exception>
      public Trackable Root
      {
          get
          {
              if (!_root_ref.TryGetTarget(out Trackable target))
              {
                  throw new InvalidDataException(
                      "Cannot get the object from the weak reference. Please consider if a null reference is passed to the constructor.");
              }

              return target;
          }
      }

      /// <summary>
      /// Walks the objects reachable from <see cref="Root"/> breadth first and records, for each one,
      /// the chain of references that first led to it.
      /// Corresponding to `_descendants_with_paths` in tensorflow/python/checkpoint/trackable_view.
      /// </summary>
      /// <returns>
      /// The visited objects in visiting order, and a map from each object to its path from the root
      /// (the root itself maps to an empty path).
      /// </returns>
      protected (IList<Trackable>, IDictionary<Trackable, IEnumerable<TrackableReference>>) _descendants_with_paths()
      {
          var root = Root;

          var visit_order = new List<Trackable>();
          var pending = new Queue<Trackable>();
          var paths = new Dictionary<Trackable, IEnumerable<TrackableReference>>
          {
              [root] = new List<TrackableReference>(),
          };
          pending.Enqueue(root);

          while (!pending.empty())
          {
              var node = pending.Dequeue();
              visit_order.Add(node);

              foreach (var edge in this.children(node))
              {
                  var dependency = edge.Value;
                  if (paths.ContainsKey(dependency))
                  {
                      // Already reached earlier; keep the first recorded path.
                      continue;
                  }

                  // Path to the child = path to the parent + the edge just followed.
                  var path = new List<TrackableReference>(paths[node])
                  {
                      new TrackableReference(edge.Key, dependency),
                  };
                  paths[dependency] = path;
                  pending.Enqueue(dependency);
              }
          }

          return (visit_order, paths);
      }
  }