You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

nest.py.cs 44 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957
  1. using System;
  2. using System.Collections;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using NumSharp;
  6. namespace Tensorflow.Util
  7. {
  8. //Functions for working with arbitrarily nested sequences of elements.
  9. //This module can perform operations on nested structures. A nested structure is a
  10. //Python sequence, tuple (including `namedtuple`), or dict that can contain
  11. //further sequences, tuples, and dicts.
  12. //The utilities here assume (and do not check) that the nested structures form a
  13. //'tree', i.e., no references in the structure of the input of these functions
  14. //should be recursive.
  15. //Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0),
  16. // (np.array([3, 4]), tf.constant([3, 4])))`
  17. //
  18. public static class nest
  19. {
  /// <summary>
  /// Untyped implementation of zip for arbitrary data.
  ///
  /// Converts a list of lists or arrays [[1,2,3], [4,5,6], [7,8,9]] into a list of arrays
  /// representing tuples of the same index of all source arrays [[1,4,7], [2,5,8], [3,6,9]].
  /// </summary>
  /// <remarks>
  /// The number of tuples produced is the length of the first sequence. Arrays and lists are
  /// indexed directly, so a shorter array/list throws when read past its end. Any other sequence
  /// is read once into a snapshot and contributes null once it has no element at an index.
  /// </remarks>
  /// <param name="lists">one or multiple sequences to be zipped</param>
  /// <returns>One object[] per index of the first sequence; entry j of each array comes from lists[j].</returns>
  public static IEnumerable<object[]> zip_many(params IEnumerable<object>[] lists)
  {
      if (lists.Length == 0)
          yield break;
      var first = lists[0];
      if (first == null)
          yield break;
      var arity = first.Count();
      // Snapshots of the non-indexable sequences, filled on first access. Reading such a
      // sequence once avoids re-enumerating it for every index (quadratic work, and
      // inconsistent results if it is a lazy query whose output can change between runs).
      var snapshots = new List<object>[lists.Length];
      for (int i = 0; i < arity; i++)
      {
          var array = new object[lists.Length];
          for (int j = 0; j < lists.Length; j++)
          {
              var sequence = lists[j];
              if (sequence is IList)
              {
                  // Arrays and lists: direct indexing.
                  array[j] = GetSequenceElementAt(sequence, i);
                  continue;
              }
              var snapshot = snapshots[j];
              if (snapshot == null)
              {
                  snapshot = _yield_value(sequence).ToList();
                  snapshots[j] = snapshot;
              }
              // Past the end of a shorter sequence the element is null.
              array[j] = i < snapshot.Count ? snapshot[i] : null;
          }
          yield return array;
      }
  }
  /// <summary>
  /// Returns the element at position <paramref name="i"/> of <paramref name="sequence"/>.
  /// Arrays and non-generic lists are indexed directly (an out-of-range index throws).
  /// Any other sequence is walked through _yield_value and gives null when it is too short.
  /// </summary>
  private static object GetSequenceElementAt(object sequence, int i)
  {
      if (sequence is Array array)
          return array.GetValue(i);
      if (sequence is IList list)
          return list[i];
      var toSkip = i < 0 ? 0 : i;
      return _yield_value(sequence).Skip(toSkip).FirstOrDefault();
  }
  /// <summary>
  /// Typed two-sequence zip producing value tuples; delegates to <c>Python.zip</c>.
  /// </summary>
  public static IEnumerable<(T1, T2)> zip<T1, T2>(IEnumerable<T1> e1, IEnumerable<T2> e2)
  {
      return Python.zip(e1, e2);
  }
  /// <summary>
  /// Converts <paramref name="dyn"/> into a string-keyed dictionary; delegates to <c>Python.ConvertToDict</c>.
  /// </summary>
  public static Dictionary<string, object> ConvertToDict(object dyn)
  {
      return Python.ConvertToDict(dyn);
  }
  60. //def _get_attrs_values(obj):
  61. // """Returns the list of values from an attrs instance."""
  62. // attrs = getattr(obj.__class__, "__attrs_attrs__")
  63. // return [getattr(obj, a.name) for a in attrs]
  /// <summary>
  /// Returns the keys of <paramref name="dict_"/> in ascending order, using the default comparer.
  /// Sorting is deferred until the result is enumerated. Keys that cannot be compared with each
  /// other make that enumeration throw.
  /// </summary>
  private static IEnumerable<object> _sorted(IDictionary dict_)
  {
      return from key in dict_.Keys.OfType<object>()
             orderby key
             select key;
  }
  71. //def _is_namedtuple(instance, strict=False):
  72. // """Returns True iff `instance` is a `namedtuple`.
  73. // Args:
  74. // instance: An instance of a Python object.
  75. // strict: If True, `instance` is considered to be a `namedtuple` only if
  76. // it is a "plain" namedtuple. For instance, a class inheriting
  77. // from a `namedtuple` will be considered to be a `namedtuple`
  78. // iff `strict=False`.
  79. // Returns:
  80. // True if `instance` is a `namedtuple`.
  81. // """
  82. // return _pywrap_tensorflow.IsNamedtuple(instance, strict)
  83. //# See the swig file (util.i) for documentation.
  84. //_is_mapping = _pywrap_tensorflow.IsMapping
  85. //_is_attrs = _pywrap_tensorflow.IsAttrs
  /// <summary>
  /// Converts the sequence `args` to the same type as `instance`.
  /// </summary>
  /// <remarks>
  /// Supported instance types:
  /// * any <see cref="IDictionary"/>: `args` are assigned to the instance's keys taken in sorted
  ///   order. A <see cref="Hashtable"/> yields a new Hashtable. Any other dictionary yields a new
  ///   instance of its runtime type, which needs a public parameterless constructor. A value the
  ///   target dictionary cannot hold (e.g. a wrong value type for a generic Dictionary) makes the
  ///   dictionary's indexer throw.
  /// * <c>object[]</c> (including covariant arrays such as <c>string[]</c>): a new object[].
  /// * <c>List&lt;object&gt;</c>: a new List&lt;object&gt;.
  /// </remarks>
  /// <param name="instance">the sequence whose type should be mimicked.</param>
  /// <param name="args">elements to be converted to the `instance` type.</param>
  /// <returns>`args` with the type of `instance`.</returns>
  /// <exception cref="TypeError">If `instance` is null or of an unsupported type.</exception>
  private static object _sequence_like(object instance, IEnumerable<object> args)
  {
      if (instance == null)
          throw new TypeError("Type of sequence not supported (yet): null");
      if (is_mapping(instance))
      {
          // Pack dictionaries in a deterministic order by sorting the keys.
          // Notice this means that we ignore the original order of ordered dictionaries.
          // This is intentional, to avoid potential bugs caused by mixing ordered and
          // plain dicts (e.g., flattening a dict but using a corresponding ordered dict
          // to pack it back).
          var source = (IDictionary)instance;
          IDictionary result;
          if (instance is Hashtable)
          {
              result = new Hashtable();
          }
          else
          {
              // Other dictionary types are re-created through their parameterless constructor;
              // flatten and _yield_value already accept every IDictionary, so packing must too.
              var type = instance.GetType();
              if (type.GetConstructor(Type.EmptyTypes) == null)
                  throw new TypeError("Type of sequence not supported (yet): " + type);
              result = (IDictionary)Activator.CreateInstance(type);
          }
          foreach ((object key, object value) in zip<object, object>(_sorted(source), args))
              result[key] = value;
          return result;
      }
      //else if( _is_namedtuple(instance) || _is_attrs(instance))
      //    return type(instance)(*args)
      // Not a namedtuple
      switch (instance)
      {
          case object[] _:
              // Single pass over args (the previous Count() + foreach enumerated it twice).
              return args.ToArray();
          case List<object> _:
              return new List<object>(args);
          default:
              throw new TypeError("Type of sequence not supported (yet): " + instance.GetType());
      }
  }
  /// <summary>
  /// Lazily yields the direct children of <paramref name="iterable"/>.
  /// A dictionary yields its values ordered by sorted key. Any other <see cref="IEnumerable"/>
  /// yields its items in enumeration order. Anything else throws a TypeError once enumeration starts.
  /// </summary>
  private static IEnumerable<object> _yield_value(object iterable)
  {
      if (!is_mapping(iterable))
      {
          //else if (_is_attrs(iterable))
          //    for value in _get_attrs_values(iterable): yield value
          if (!(iterable is IEnumerable items))
              throw new TypeError("Unexpected iterable type: " + iterable.GetType());
          foreach (var item in items)
              yield return item;
          yield break;
      }
      // Iterate through dictionaries in a deterministic order by sorting the keys.
      // Notice this means that we ignore the original order of ordered dictionaries.
      // This is intentional, to avoid potential bugs caused by mixing ordered and
      // plain dicts (e.g., flattening a dict but using a corresponding ordered dict
      // to pack it back).
      var mapping = (IDictionary)iterable;
      foreach (var key in _sorted(mapping))
          yield return mapping[key];
  }
  /// <summary>
  /// Returns true when <paramref name="arg"/> is treated as a nested sequence rather than a leaf.
  /// That means any <see cref="IEnumerable"/> except strings, NDArrays and <c>HashSet&lt;T&gt;</c>
  /// instances; null is not a sequence.
  /// Python counterpart: nest.is_sequence (documented in the swig file util.i).
  /// </summary>
  public static bool is_sequence(object arg)
  {
      if (!(arg is IEnumerable))
          return false;
      if (arg is string || arg is NDArray)
          return false;
      var type = arg.GetType();
      return !(type.IsGenericType && type.GetGenericTypeDefinition() == typeof(HashSet<>));
  }
  /// <summary>
  /// Returns true when <paramref name="arg"/> is a dictionary (any <see cref="IDictionary"/>); null is not a mapping.
  /// </summary>
  public static bool is_mapping(object arg)
  {
      var mapping = arg as IDictionary;
      return mapping != null;
  }
  /// <summary>
  /// Returns the leaves of <paramref name="structure"/> as a flat list, collected depth first, with
  /// dictionary values visited in sorted-key order. A scalar such as a number, string or NDArray
  /// gives a one-element list.
  /// Python counterpart: _pywrap_tensorflow.Flatten (documented in the swig file util.i).
  /// </summary>
  public static List<object> flatten(object structure)
  {
      var leaves = new List<object>();
      _flatten_recursive(structure, leaves);
      return leaves;
  }
  /// <summary>
  /// Depth-first helper for flatten: appends every leaf of <paramref name="obj"/> to <paramref name="list"/>.
  /// </summary>
  /// <remarks>
  /// Whether a node is descended into is decided by is_sequence, the same predicate that
  /// pack_sequence_as uses. This keeps flatten and pack_sequence_as in agreement on leaf counts.
  /// For example, strings, NDArrays and HashSets are single leaves for both.
  /// Children come from _yield_value, so dictionaries contribute their values in sorted-key order.
  /// </remarks>
  /// <param name="obj">the (sub)structure to flatten; null and non-sequences are leaves.</param>
  /// <param name="list">accumulator receiving the leaves in order.</param>
  private static void _flatten_recursive(object obj, List<object> list)
  {
      if (!is_sequence(obj))
      {
          list.Add(obj);
          return;
      }
      foreach (var child in _yield_value(obj))
          _flatten_recursive(child, list);
  }
  211. //# See the swig file (util.i) for documentation.
  212. //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
  213. //class _DotString(object):
  214. // def __str__(self):
  215. // return "."
  216. // def __repr__(self):
  217. // return "."
  218. //_DOT = _DotString()
  219. //def assert_same_structure(nest1, nest2, check_types=True):
  220. // """Asserts that two structures are nested in the same way.
  221. // Note that namedtuples with identical name and fields are always considered
  222. // to have the same shallow structure (even with `check_types=True`).
  223. // For intance, this code will print `True`:
  224. // ```python
  225. // def nt(a, b):
  226. // return collections.namedtuple('foo', 'a b')(a, b)
  227. // print(assert_same_structure(nt(0, 1), nt(2, 3)))
  228. // ```
  229. // Args:
  230. // nest1: an arbitrarily nested structure.
  231. // nest2: an arbitrarily nested structure.
  232. // check_types: if `True` (default) types of sequences are checked as well,
  233. // including the keys of dictionaries. If set to `False`, for example a
  234. // list and a tuple of objects will look the same if they have the same
  235. // size. Note that namedtuples with identical name and fields are always
  236. // considered to have the same shallow structure. Two types will also be
  237. // considered the same if they are both list subtypes (which allows "list"
  238. // and "_ListWrapper" from checkpointable dependency tracking to compare
  239. // equal).
  240. // Raises:
  241. // ValueError: If the two structures do not have the same number of elements or
  242. // if the two structures are not nested in the same way.
  243. // TypeError: If the two structures differ in the type of sequence in any of
  244. // their substructures. Only possible if `check_types` is `True`.
  245. // """
  246. // try:
  247. // _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types)
  248. // except (ValueError, TypeError) as e:
  249. // str1 = str(map_structure(lambda _: _DOT, nest1))
  250. // str2 = str(map_structure(lambda _: _DOT, nest2))
  251. // raise type(e)("%s\n"
  252. // "Entire first structure:\n%s\n"
  253. // "Entire second structure:\n%s"
  254. // % (str(e), str1, str2))
  255. //def flatten_dict_items(dictionary):
  256. // """Returns a dictionary with flattened keys and values.
  257. // This function flattens the keys and values of a dictionary, which can be
  258. // arbitrarily nested structures, and returns the flattened version of such
  259. // structures:
  260. // ```python
  261. // example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
  262. // result = {4: "a", 5: "b", 6: "c", 8: "d"}
  263. // flatten_dict_items(example_dictionary) == result
  264. // ```
  265. // The input dictionary must satisfy two properties:
  266. // 1. Its keys and values should have the same exact nested structure.
  267. // 2. The set of all flattened keys of the dictionary must not contain repeated
  268. // keys.
  269. // Args:
  270. // dictionary: the dictionary to zip
  271. // Returns:
  272. // The zipped dictionary.
  273. // Raises:
  274. // TypeError: If the input is not a dictionary.
  275. // ValueError: If any key and value have not the same structure, or if keys are
  276. // not unique.
  277. // """
  278. // if not isinstance(dictionary, (dict, _collections.Mapping)):
  279. // raise TypeError("input must be a dictionary")
  280. // flat_dictionary = {}
  281. // for i, v in _six.iteritems(dictionary):
  282. // if not is_sequence(i):
  283. // if i in flat_dictionary:
  284. // raise ValueError(
  285. // "Could not flatten dictionary: key %s is not unique." % i)
  286. // flat_dictionary[i] = v
  287. // else:
  288. // flat_i = flatten(i)
  289. // flat_v = flatten(v)
  290. // if len(flat_i) != len(flat_v):
  291. // raise ValueError(
  292. // "Could not flatten dictionary. Key had %d elements, but value had "
  293. // "%d elements. Key: %s, value: %s."
  294. // % (len(flat_i), len(flat_v), flat_i, flat_v))
  295. // for new_i, new_v in zip(flat_i, flat_v):
  296. // if new_i in flat_dictionary:
  297. // raise ValueError(
  298. // "Could not flatten dictionary: key %s is not unique."
  299. // % (new_i))
  300. // flat_dictionary[new_i] = new_v
  301. // return flat_dictionary
  /// <summary>
  /// Helper for pack_sequence_as: rebuilds the children of <paramref name="structure"/> from
  /// <paramref name="flat"/>, starting at <paramref name="index"/>.
  /// Each non-sequence child consumes one element of <paramref name="flat"/>. Each sequence child is
  /// rebuilt recursively and converted back to its own type with _sequence_like.
  /// </summary>
  /// <param name="structure">Substructure (list / array / dict) to mimic.</param>
  /// <param name="flat">Flattened values to output substructure for.</param>
  /// <param name="index">Index at which to start reading from flat.</param>
  /// <returns>
  /// The tuple (new_index, child), where:
  /// * new_index - the index into `flat` just past the elements consumed for `structure`.
  /// * child - the packed children of `structure`, read from `flat` starting at `index`.
  /// Reading past the end of `flat` throws from the List indexer.
  /// </returns>
  private static (int new_index, List<object> child) _packed_nest_with_indices(object structure, List<object> flat,
      int index)
  {
      var children = new List<object>();
      foreach (var element in _yield_value(structure))
      {
          if (!is_sequence(element))
          {
              children.Add(flat[index++]);
              continue;
          }
          var nested = _packed_nest_with_indices(element, flat, index);
          index = nested.new_index;
          children.Add(_sequence_like(element, nested.child));
      }
      return (index, children);
  }
  /// <summary>
  /// Python-style len(): the number of elements in <paramref name="x"/>.
  /// The sequence is enumerated if it is not a collection.
  /// </summary>
  private static int len(IEnumerable<object> x)
  {
      return Enumerable.Count(x);
  }
  /// <summary>
  /// Returns a given flattened sequence packed into a given structure.
  /// If `structure` is a scalar, `flat_sequence` must be a single-element list;
  /// in this case the return value is `flat_sequence[0]`.
  ///
  /// If `structure` is or contains a dict instance, the keys will be sorted to
  /// pack the flat sequence in deterministic order. This is true also for
  /// ordered dictionaries: their sequence order is ignored, and the sorting order of
  /// keys is used instead. The same convention is followed in `flatten`.
  /// This correctly repacks dicts and ordered dicts after they have been
  /// flattened, and also allows flattening an ordered dict and then repacking it
  /// using a corresponding plain dict, or vice versa.
  /// Dictionaries with non-sortable keys cannot be flattened.
  /// </summary>
  /// <param name="structure">
  /// Nested structure, whose structure is given by nested lists,
  /// arrays, and dicts. Note: NDArrays and strings are considered
  /// scalars.
  /// </param>
  /// <param name="flat_sequence">flat sequence to pack.</param>
  /// <returns>`flat_sequence` converted to have the same recursive structure as
  /// `structure`.
  /// </returns>
  /// <exception cref="ArgumentNullException">If `flat_sequence` is null.</exception>
  /// <exception cref="ValueError">If `structure` is a scalar and `flat_sequence` does not have exactly
  /// one element, or if the number of leaves of `structure` differs from the length of `flat_sequence`.</exception>
  public static object pack_sequence_as(object structure, IEnumerable<object> flat_sequence)
  {
      // Validate before copying: new List<object>(null) would otherwise throw first.
      if (flat_sequence == null)
          throw new ArgumentNullException(nameof(flat_sequence), "flat_sequence must not be null");
      var flat = flat_sequence as List<object> ?? new List<object>(flat_sequence);
      // if not is_sequence(flat_sequence):
      //     raise TypeError("flat_sequence must be a sequence")
      if (!is_sequence(structure))
      {
          if (len(flat) != 1)
              throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat)} != 1");
          return flat[0];
      }
      int final_index = 0;
      List<object> packed = null;
      try
      {
          (final_index, packed) = _packed_nest_with_indices(structure, flat, 0);
          if (final_index < len(flat))
              throw new IndexOutOfRangeException(
                  $"Final index: {final_index} was smaller than len(flat_sequence): {len(flat)}");
          return _sequence_like(structure, packed);
      }
      // IndexOutOfRange: leftover flat elements; ArgumentOutOfRange: flat ran out (List indexer).
      catch (Exception ex) when (ex is IndexOutOfRangeException || ex is ArgumentOutOfRangeException)
      {
          var flat_structure = flatten(structure);
          if (len(flat_structure) != len(flat))
          {
              throw new ValueError($"Could not pack sequence. Structure had {len(flat_structure)} elements, but " +
                  $"flat_sequence had {len(flat)} elements.");
          }
          // Counts agree, so the failure did not come from a length mismatch. If packing never
          // completed there is nothing to return; surface the original error.
          if (packed == null)
              throw;
          return _sequence_like(structure, packed);
      }
  }
  /// <summary>
  /// Applies `func` to each entry in `structure` and returns a new structure.
  ///
  /// Applies `func(x[0], x[1], ...)` where x[i] is an entry in
  /// `structure[i]`. All structures in `structure` must have the same arity,
  /// and the return value will contain the results in the same structure.
  /// </summary>
  /// <param name="func">A callable that accepts as many arguments as there are structures;
  /// the arguments are passed packed into one object[].</param>
  /// <param name="structure">one or many IEnumerable of object. They are not yet checked for
  /// having the same nesting (assert_same_structure is not ported). Leaves are paired by position
  /// after flattening, and the first structure decides how many leaves are produced.</param>
  /// <returns>
  /// The top-level elements of a new structure shaped like `structure[0]`, whose leaves are
  /// `func(x[0], x[1], ...)` where `x[i]` is the leaf at the corresponding position in `structure[i]`.
  /// </returns>
  /// <exception cref="ValueError">If no structure is provided.</exception>
  public static IEnumerable<object> map_structure(Func<object[], object> func, params IEnumerable<object>[] structure)
  {
      // TODO: check structure and types
      // for other in structure[1:]:
      //     assert_same_structure(structure[0], other, check_types=check_types)
      if (structure == null || structure.Length == 0)
          throw new ValueError("Must provide at least one structure");
      if (structure.Length == 1)
      {
          // we don't need to zip if we have only one structure
          return map_structure(a => func(new object[] { a }), structure[0]);
      }
      // ToArray is important here: it passes the flattened lists as the params array (one
      // sequence per structure). A bare IEnumerable<List<object>> would bind as a single sequence.
      var flat_structures = structure.Select(flatten).ToArray();
      var entries = zip_many(flat_structures);
      var mapped_flat_structure = entries.Select(func);
      return _yield_value(pack_sequence_as(structure[0], mapped_flat_structure)).ToList();
  }
  /// <summary>
  /// Single-structure form of map_structure (no zipping of multiple structures).
  /// Applies <paramref name="func"/> to every leaf of <paramref name="structure"/> and repacks the
  /// results into the same nesting.
  /// </summary>
  /// <param name="func">Function applied to each leaf, in the order produced by flatten.</param>
  /// <param name="structure">The nested structure to map over.</param>
  /// <returns>The top-level elements of the repacked structure, materialized as a list.</returns>
  public static IEnumerable<object> map_structure(Func<object, object> func, IEnumerable<object> structure)
  {
      // TODO: check structure and types
      // for other in structure[1:]:
      //     assert_same_structure(structure[0], other, check_types=check_types)
      List<object> leaves = flatten(structure);
      List<object> mappedLeaves = leaves.Select(func).ToList();
      object repacked = pack_sequence_as(structure, mappedLeaves);
      return _yield_value(repacked).ToList();
  }
  458. //def map_structure_with_paths(func, *structure, **kwargs):
  459. // """Applies `func` to each entry in `structure` and returns a new structure.
  460. // Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in
  461. // `structure[i]` and `path` is the common path to x[i] in the structures. All
  462. // structures in `structure` must have the same arity, and the return value will
  463. // contain the results in the same structure. Special kwarg `check_types`
  464. // determines whether the types of iterables within the structure must be the
  465. // same-- see **kwargs definition below.
  466. // Args:
  467. // func: A callable with the signature func(path, *values, **kwargs) that is
  468. // evaluated on the leaves of the structure.
  469. // *structure: A variable number of compatible structures to process.
  470. // **kwargs: Optional kwargs to be passed through to func. Special kwarg
  471. // `check_types` is not passed to func, but instead determines whether the
  472. // types of iterables within the structures have to be same (e.g.,
  473. // `map_structure(func, [1], (1,))` raises a `TypeError` exception). By
  474. // default, the types must match. To allow iteration over structures of
  475. // different types (but common arity), set this kwarg to `False`.
  476. // Returns:
  477. // A structure of the same form as the input structures whose leaves are the
  478. // result of evaluating func on corresponding leaves of the input structures.
  479. // Raises:
  480. // TypeError: If `func` is not callable or if the structures do not match
  481. // each other by depth tree.
  482. // TypeError: If `check_types` is not `False` and the two structures differ in
  483. // the type of sequence in any of their substructures.
  484. // ValueError: If no structures are provided.
  485. // """
  486. // if not callable(func):
  487. // raise TypeError("func must be callable, got: %s" % func)
  488. // if not structure:
  489. // raise ValueError("Must provide at least one structure")
  490. // check_types = kwargs.pop("check_types", True)
  491. // for other in structure[1:]:
  492. // assert_same_structure(structure[0], other, check_types=check_types)
  493. //# First set paths_and_values to:
  494. //# [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]]
  495. // paths_and_values = [flatten_with_joined_string_paths(s) for s in structure]
  496. //# Now zip(*paths_and_values) would be:
  497. //# [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))]
  498. //# so grouped_by_path is set to:
  499. //# [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]]
  500. //# Note that p1i, ... pmi must all be equal since the structures are the same.
  501. // grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)]
  502. // return pack_sequence_as(structure[0], [
  503. // func(paths[0], *values, **kwargs) for paths, values in grouped_by_path])
  504. //def _yield_flat_up_to(shallow_tree, input_tree):
  505. // """Yields elements `input_tree` partially flattened up to `shallow_tree`."""
  506. // if is_sequence(shallow_tree):
  507. // for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
  508. // _yield_value(input_tree)):
  509. // for input_leaf in _yield_flat_up_to(shallow_branch, input_branch):
  510. // yield input_leaf
  511. // else:
  512. // yield input_tree
  513. //def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
  514. // """Asserts that `shallow_tree` is a shallow structure of `input_tree`.
  515. // That is, this function tests if the `input_tree` structure can be created from
  516. // the `shallow_tree` structure by replacing its leaf nodes with deeper
  517. // tree structures.
  518. // Examples:
  519. // The following code will raise an exception:
  520. // ```python
  521. // shallow_tree = ["a", "b"]
  522. // input_tree = ["c", ["d", "e"], "f"]
  523. // assert_shallow_structure(shallow_tree, input_tree)
  524. // ```
  525. // The following code will not raise an exception:
  526. // ```python
  527. // shallow_tree = ["a", "b"]
  528. // input_tree = ["c", ["d", "e"]]
  529. // assert_shallow_structure(shallow_tree, input_tree)
  530. // ```
  531. // Args:
  532. // shallow_tree: an arbitrarily nested structure.
  533. // input_tree: an arbitrarily nested structure.
  534. // check_types: if `True` (default) the sequence types of `shallow_tree` and
  535. // `input_tree` have to be the same. Note that even with check_types==True,
  536. // this function will consider two different namedtuple classes with the same
  537. // name and _fields attribute to be the same class.
  538. // Raises:
  539. // TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
  540. // TypeError: If the sequence types of `shallow_tree` are different from
  541. // `input_tree`. Only raised if `check_types` is `True`.
  542. // ValueError: If the sequence lengths of `shallow_tree` are different from
  543. // `input_tree`.
  544. // """
  545. // if is_sequence(shallow_tree):
  546. // if not is_sequence(input_tree):
  547. // raise TypeError(
  548. // "If shallow structure is a sequence, input must also be a sequence. "
  549. // "Input has type: %s." % type(input_tree))
  550. // if check_types and not isinstance(input_tree, type(shallow_tree)):
  551. //# Duck-typing means that nest should be fine with two different
  552. //# namedtuples with identical name and fields.
  553. // shallow_is_namedtuple = _is_namedtuple(shallow_tree, False)
  554. // input_is_namedtuple = _is_namedtuple(input_tree, False)
  555. // if shallow_is_namedtuple and input_is_namedtuple:
  556. // if not _same_namedtuples(shallow_tree, input_tree):
  557. // raise TypeError(
  558. // "The two namedtuples don't have the same sequence type. Input "
  559. // "structure has type %s, while shallow structure has type %s."
  560. // % (type(input_tree), type(shallow_tree)))
  561. // elif not (isinstance(shallow_tree, _collections.Mapping)
  562. // and isinstance(input_tree, _collections.Mapping)):
  563. // raise TypeError(
  564. // "The two structures don't have the same sequence type. Input "
  565. // "structure has type %s, while shallow structure has type %s."
  566. // % (type(input_tree), type(shallow_tree)))
  567. // if len(input_tree) != len(shallow_tree):
  568. // raise ValueError(
  569. // "The two structures don't have the same sequence length. Input "
  570. // "structure has length %s, while shallow structure has length %s."
  571. // % (len(input_tree), len(shallow_tree)))
  572. // if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)):
  573. // if set(input_tree) != set(shallow_tree):
  574. // raise ValueError(
  575. // "The two structures don't have the same keys. Input "
  576. // "structure has keys %s, while shallow structure has keys %s." %
  577. // (list(_six.iterkeys(input_tree)),
  578. // list(_six.iterkeys(shallow_tree))))
  579. // input_tree = list(sorted(_six.iteritems(input_tree)))
  580. // shallow_tree = list(sorted(_six.iteritems(shallow_tree)))
  581. // for shallow_branch, input_branch in zip(shallow_tree, input_tree):
  582. // assert_shallow_structure(shallow_branch, input_branch,
  583. // check_types=check_types)
  584. //def flatten_up_to(shallow_tree, input_tree):
  585. // """Flattens `input_tree` up to `shallow_tree`.
  586. // Any further depth in structure in `input_tree` is retained as elements in the
  587. // partially flatten output.
  588. // If `shallow_tree` and `input_tree` are not sequences, this returns a
  589. // single-element list: `[input_tree]`.
  590. // Use Case:
  591. // Sometimes we may wish to partially flatten a nested sequence, retaining some
  592. // of the nested structure. We achieve this by specifying a shallow structure,
  593. // `shallow_tree`, we wish to flatten up to.
  594. // The input, `input_tree`, can be thought of as having the same structure as
  595. // `shallow_tree`, but with leaf nodes that are themselves tree structures.
  596. // Examples:
  597. // ```python
  598. // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  599. // shallow_tree = [[True, True], [False, True]]
  600. // flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
  601. // flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)
  602. //# Output is:
  603. //# [[2, 2], [3, 3], [4, 9], [5, 5]]
  604. //# [True, True, False, True]
  605. // ```
  606. // ```python
  607. // input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  608. // shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]
  609. // input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
  610. // input_tree_flattened = flatten(input_tree)
  611. //# Output is:
  612. //# [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
  613. //# ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  614. // ```
  615. // Non-Sequence Edge Cases:
  616. // ```python
  617. // flatten_up_to(0, 0) # Output: [0]
  618. // flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]]
  619. // flatten_up_to([0, 1, 2], 0) # Output: TypeError
  620. // flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2]
  621. // ```
  622. // Args:
  623. // shallow_tree: a possibly pruned structure of input_tree.
  624. // input_tree: an arbitrarily nested structure or a scalar object.
  625. // Note, numpy arrays are considered scalars.
  626. // Returns:
  627. // A Python list, the partially flattened version of `input_tree` according to
  628. // the structure of `shallow_tree`.
  629. // Raises:
  630. // TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
  631. // TypeError: If the sequence types of `shallow_tree` are different from
  632. // `input_tree`.
  633. // ValueError: If the sequence lengths of `shallow_tree` are different from
  634. // `input_tree`.
  635. // """
  636. // assert_shallow_structure(shallow_tree, input_tree)
  637. // return list(_yield_flat_up_to(shallow_tree, input_tree))
  638. //def map_structure_up_to(shallow_tree, func, *inputs):
  639. // """Applies a function or op to a number of partially flattened inputs.
  640. // The `inputs` are flattened up to `shallow_tree` before being mapped.
  641. // Use Case:
  642. // Sometimes we wish to apply a function to a partially flattened
  643. // sequence (for example when the function itself takes sequence inputs). We
  644. // achieve this by specifying a shallow structure, `shallow_tree` we wish to
  645. // flatten up to.
  646. // The `inputs`, can be thought of as having the same structure as
  647. // `shallow_tree`, but with leaf nodes that are themselves tree structures.
  648. // This function therefore will return something with the same base structure as
  649. // `shallow_tree`.
  650. // Examples:
  651. // ```python
  652. // ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  653. // op_tuple = collections.namedtuple("op_tuple", "add, mul")
  654. // inp_val = ab_tuple(a=2, b=3)
  655. // inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  656. // out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
  657. // inp_val, inp_ops)
  658. //# Output is: ab_tuple(a=6, b=15)
  659. // ```
  660. // ```python
  661. // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  662. // name_list = ['evens', ['odds', 'primes']]
  663. // out = map_structure_up_to(
  664. // name_list,
  665. // lambda name, sec: "first_{}_{}".format(len(sec), name),
  666. // name_list, data_list)
  667. //# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
  668. // ```
  669. // Args:
  670. // shallow_tree: a shallow tree, common to all the inputs.
  671. // func: callable which will be applied to each input individually.
  672. // *inputs: arbitrarily nested combination of objects that are compatible with
  673. // shallow_tree. The function `func` is applied to corresponding
  674. // partially flattened elements of each input, so the function must support
  675. // arity of `len(inputs)`.
  676. // Raises:
  677. // TypeError: If `shallow_tree` is a sequence but an element of `inputs` is not.
  678. // TypeError: If the sequence types of `shallow_tree` are different from
  679. // those of an element of `inputs`.
  680. // ValueError: If the sequence lengths of `shallow_tree` are different from
  681. // those of an element of `inputs`, or if `inputs` is empty.
  682. // Returns:
  683. // result of repeatedly applying `func`, with same structure as
  684. // `shallow_tree`.
  685. // """
  686. // if not inputs:
  687. // raise ValueError("Cannot map over no sequences")
  688. // for input_tree in inputs:
  689. // assert_shallow_structure(shallow_tree, input_tree)
  690. //# Flatten each input separately, apply the function to corresponding elements,
  691. //# then repack based on the structure of `shallow_tree`.
  692. // all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree)
  693. // for input_tree in inputs]
  694. // results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
  695. // return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
  696. //def get_traverse_shallow_structure(traverse_fn, structure):
  697. // """Generates a shallow structure from a `traverse_fn` and `structure`.
  698. // `traverse_fn` must accept any possible subtree of `structure` and return
  699. // a depth=1 structure containing `True` or `False` values, describing which
  700. // of the top-level subtrees may be traversed. It may also return a scalar
  701. // `True` or `False`, meaning "traversal is OK / not OK for all subtrees."
  702. // Examples are available in the unit tests (nest_test.py).
  703. // Args:
  704. // traverse_fn: Function taking a substructure and returning either a scalar
  705. // `bool` (whether to traverse that substructure or not) or a depth=1
  706. // shallow structure of the same type, describing which parts of the
  707. // substructure to traverse.
  708. // structure: The structure to traverse.
  709. // Returns:
  710. // A shallow structure containing python bools, which can be passed to
  711. // `map_structure_up_to` and `flatten_up_to`.
  712. // Raises:
  713. // TypeError: if `traverse_fn` returns a sequence for a non-sequence input,
  714. // or a structure with depth higher than 1 for a sequence input,
  715. // or if any leaf values in the returned structure or scalar are not type
  716. // `bool`.
  717. // """
  718. // to_traverse = traverse_fn(structure)
  719. // if not is_sequence(structure):
  720. // if not isinstance(to_traverse, bool):
  721. // raise TypeError("traverse_fn returned structure: %s for non-structure: %s"
  722. // % (to_traverse, structure))
  723. // return to_traverse
  724. // level_traverse = []
  725. // if isinstance(to_traverse, bool):
  726. // if not to_traverse:
  727. //# Do not traverse this substructure at all. Exit early.
  728. // return False
  729. // else:
  730. //# Traverse the entire substructure.
  731. // for branch in _yield_value(structure):
  732. // level_traverse.append(
  733. // get_traverse_shallow_structure(traverse_fn, branch))
  734. // elif not is_sequence(to_traverse):
  735. // raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s"
  736. // % (to_traverse, structure))
  737. // else:
  738. //# Traverse some subset of this substructure.
  739. // assert_shallow_structure(to_traverse, structure)
  740. // for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)):
  741. // if not isinstance(t, bool):
  742. // raise TypeError(
  743. // "traverse_fn didn't return a depth=1 structure of bools. saw: %s "
  744. // " for structure: %s" % (to_traverse, structure))
  745. // if t:
  746. // level_traverse.append(
  747. // get_traverse_shallow_structure(traverse_fn, branch))
  748. // else:
  749. // level_traverse.append(False)
  750. // return _sequence_like(structure, level_traverse)
  751. //def yield_flat_paths(nest):
  752. // """Yields paths for some nested structure.
  753. // Paths are lists of objects which can be str-converted, which may include
  754. // integers or other types which are used as indices in a dict.
  755. // The flat list will be in the corresponding order as if you called
  756. // `snt.nest.flatten` on the structure. This is handy for naming Tensors such
  757. // that the TF scope structure matches the tuple structure.
  758. // E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))`
  759. // ```shell
  760. // >>> nest.flatten(value)
  761. // [3, 23, 42]
  762. // >>> list(nest.yield_flat_paths(value))
  763. // [('a',), ('b', 'c'), ('b', 'd')]
  764. // ```
  765. // ```shell
  766. // >>> list(nest.yield_flat_paths({'a': [3]}))
  767. // [('a', 0)]
  768. // >>> list(nest.yield_flat_paths({'a': 3}))
  769. // [('a',)]
  770. // ```
  771. // Args:
  772. // nest: the value to produce a flattened paths list for.
  773. // Yields:
  774. // Tuples containing index or key values which form the path to a specific
  775. // leaf value in the nested structure.
  776. // """
  777. //# Leaves yield the empty path `()` (no `_maybe_add_final_path_element` helper is actually called),
  778. //# so joined paths get no trailing separator when the sub-element recursed into is a leaf.
  779. // if isinstance(nest, (dict, _collections.Mapping)):
  780. // for key in _sorted(nest):
  781. // value = nest[key]
  782. // for sub_path in yield_flat_paths(value):
  783. // yield (key,) + sub_path
  784. // elif _is_namedtuple(nest):
  785. // for key in nest._fields:
  786. // value = getattr(nest, key)
  787. // for sub_path in yield_flat_paths(value):
  788. // yield (key,) + sub_path
  789. // elif isinstance(nest, _six.string_types):
  790. // yield ()
  791. // elif isinstance(nest, _collections.Sequence):
  792. // for idx, value in enumerate(nest):
  793. // for sub_path in yield_flat_paths(value):
  794. // yield (idx,) + sub_path
  795. // else:
  796. // yield ()
  797. //def flatten_with_joined_string_paths(structure, separator="/"):
  798. // """Returns a list of (string path, data element) tuples.
  799. // The order of tuples produced matches that of `nest.flatten`. This allows you
  800. // to flatten a nested structure while keeping information about where in the
  801. // structure each data element was located. See `nest.yield_flat_paths`
  802. // for more information.
  803. // Args:
  804. // structure: the nested structure to flatten.
  805. // separator: string to separate levels of hierarchy in the results, defaults
  806. // to '/'.
  807. // Returns:
  808. // A list of (string, data element) tuples.
  809. // """
  810. // flat_paths = yield_flat_paths(structure)
  811. // def stringify_and_join(path_elements):
  812. // return separator.join(str(path_element) for path_element in path_elements)
  813. // flat_string_paths = [stringify_and_join(path) for path in flat_paths]
  814. // return list(zip(flat_string_paths, flatten(structure)))
  815. }
  816. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。