| @@ -0,0 +1,871 @@ | |||||
| using System; | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| using NumSharp; | |||||
| namespace Tensorflow.Util | |||||
| { | |||||
| //Functions for working with arbitrarily nested sequences of elements. | |||||
| //This module can perform operations on nested structures. A nested structure is a | |||||
| //Python sequence, tuple (including `namedtuple`), or dict that can contain | |||||
| //further sequences, tuples, and dicts. | |||||
| //The utilities here assume (and do not check) that the nested structures form a | |||||
| //'tree', i.e., no references in the structure of the input of these functions | |||||
| //should be recursive. | |||||
| //Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0), | |||||
| // (np.array([3, 4]), tf.constant([3, 4])))` | |||||
| // | |||||
| public static class nest | |||||
| { | |||||
| //def _get_attrs_values(obj): | |||||
| // """Returns the list of values from an attrs instance.""" | |||||
| // attrs = getattr(obj.__class__, "__attrs_attrs__") | |||||
| // return [getattr(obj, a.name) for a in attrs] | |||||
| /// <summary> | |||||
| /// Returns a sorted list of the dict keys, with error if keys not sortable. | |||||
| /// </summary> | |||||
| private static IEnumerable<string> _sorted(IDictionary dict_) | |||||
| { | |||||
| return dict_.Keys.OfType<string>().OrderBy(x => x); | |||||
| } | |||||
| //def _is_namedtuple(instance, strict=False): | |||||
| // """Returns True iff `instance` is a `namedtuple`. | |||||
| // Args: | |||||
| // instance: An instance of a Python object. | |||||
| // strict: If True, `instance` is considered to be a `namedtuple` only if | |||||
| // it is a "plain" namedtuple. For instance, a class inheriting | |||||
| // from a `namedtuple` will be considered to be a `namedtuple` | |||||
| // iff `strict=False`. | |||||
| // Returns: | |||||
| // True if `instance` is a `namedtuple`. | |||||
| // """ | |||||
| // return _pywrap_tensorflow.IsNamedtuple(instance, strict) | |||||
| //# See the swig file (util.i) for documentation. | |||||
| //_is_mapping = _pywrap_tensorflow.IsMapping | |||||
| //_is_attrs = _pywrap_tensorflow.IsAttrs | |||||
| /// <summary> | |||||
| /// Converts the sequence `args` to the same type as `instance`. | |||||
| /// </summary> | |||||
| /// <param name="instance">an instance of `tuple`, `list`, `namedtuple`, `dict`, or | |||||
| /// `collections.OrderedDict`.</param> | |||||
| /// <param name="args">elements to be converted to the `instance` type.</param> | |||||
| /// <returns>`args` with the type of `instance`.</returns> | |||||
| private static object _sequence_like(object instance, IEnumerable<object> args) | |||||
| { | |||||
| if (is_mapping(instance)) | |||||
| { | |||||
| //# Pack dictionaries in a deterministic order by sorting the keys. | |||||
| //# Notice this means that we ignore the original order of `OrderedDict` | |||||
| //# instances. This is intentional, to avoid potential bugs caused by mixing | |||||
| //# ordered and plain dicts (e.g., flattening a dict but using a | |||||
| //# corresponding `OrderedDict` to pack it back). | |||||
| // result = dict(zip(_sorted(instance), args)) | |||||
| // return type(instance)((key, result[key]) for key in _six.iterkeys(instance)) | |||||
| } | |||||
| //else if( _is_namedtuple(instance) || _is_attrs(instance)) | |||||
| // return type(instance)(*args) | |||||
| else | |||||
| { | |||||
| // Not a namedtuple | |||||
| switch (instance) | |||||
| { | |||||
| case object[] array: | |||||
| var result_array = new object[args.Count()]; | |||||
| int i = 0; | |||||
| foreach (var x in args) | |||||
| { | |||||
| result_array[i] = x; | |||||
| i++; | |||||
| } | |||||
| return result_array; | |||||
| case List<object> list: | |||||
| return new List<object>(args); | |||||
| default: | |||||
| throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | |||||
| } | |||||
| } | |||||
| throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | |||||
| } | |||||
| /// <summary> | |||||
| /// Yields the next value from the given iterable. | |||||
| /// </summary> | |||||
| private static IEnumerable<object> _yield_value(object iterable) | |||||
| { | |||||
| if (is_mapping(iterable)) | |||||
| { | |||||
| var dict = iterable as IDictionary; | |||||
| //# Iterate through dictionaries in a deterministic order by sorting the | |||||
| //# keys. Notice this means that we ignore the original order of `OrderedDict` | |||||
| //# instances. This is intentional, to avoid potential bugs caused by mixing | |||||
| //# ordered and plain dicts (e.g., flattening a dict but using a | |||||
| //# corresponding `OrderedDict` to pack it back). | |||||
| foreach (var key in _sorted(dict)) | |||||
| yield return dict[key]; | |||||
| } | |||||
| //else if (_is_attrs(iterable)) | |||||
| //{ | |||||
| // // for value in _get_attrs_values(iterable): | |||||
| // // yield value | |||||
| //} | |||||
| else if (iterable is IEnumerable) | |||||
| { | |||||
| var enumerable = iterable as IEnumerable; | |||||
| foreach (var value in enumerable) | |||||
| yield return value; | |||||
| } | |||||
| else | |||||
| { | |||||
| throw new TypeError("Unexpected iterable type: " + iterable.GetType()); | |||||
| //var jobj = JObject.FromObject(iterable); | |||||
| //foreach (var key in _sorted()) | |||||
| // yield return jobj[key]; | |||||
| } | |||||
| } | |||||
| //# See the swig file (util.i) for documentation. | |||||
| public static bool is_sequence(object arg) => arg is IEnumerable && !(arg is string); | |||||
| public static bool is_mapping(object arg) => arg is IDictionary; | |||||
| //# See the swig file (util.i) for documentation. | |||||
| //flatten = _pywrap_tensorflow.Flatten | |||||
| public static List<object> flatten(object structure) | |||||
| { | |||||
| var list = new List<object>(); | |||||
| _flatten_recursive(structure, list); | |||||
| return list; | |||||
| } | |||||
| private static void _flatten_recursive(object obj, List<object> list) | |||||
| { | |||||
| if (obj is string) | |||||
| { | |||||
| list.Add(obj); | |||||
| return; | |||||
| } | |||||
| if (obj is IDictionary) | |||||
| { | |||||
| var dict = obj as IDictionary; | |||||
| foreach (var key in _sorted(dict)) | |||||
| _flatten_recursive(dict[key], list); | |||||
| return; | |||||
| } | |||||
| if (obj is NDArray) | |||||
| { | |||||
| list.Add(obj); | |||||
| return; | |||||
| } | |||||
| if (obj is IEnumerable) | |||||
| { | |||||
| var structure = obj as IEnumerable; | |||||
| foreach (var child in structure) | |||||
| _flatten_recursive(child, list); | |||||
| return; | |||||
| } | |||||
| list.Add(obj); | |||||
| } | |||||
| //# See the swig file (util.i) for documentation. | |||||
| //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples | |||||
| //class _DotString(object): | |||||
| // def __str__(self): | |||||
| // return "." | |||||
| // def __repr__(self): | |||||
| // return "." | |||||
| //_DOT = _DotString() | |||||
| //def assert_same_structure(nest1, nest2, check_types=True): | |||||
| // """Asserts that two structures are nested in the same way. | |||||
| // Note that namedtuples with identical name and fields are always considered | |||||
| // to have the same shallow structure (even with `check_types=True`). | |||||
| // For intance, this code will print `True`: | |||||
| // ```python | |||||
| // def nt(a, b): | |||||
| // return collections.namedtuple('foo', 'a b')(a, b) | |||||
| // print(assert_same_structure(nt(0, 1), nt(2, 3))) | |||||
| // ``` | |||||
| // Args: | |||||
| // nest1: an arbitrarily nested structure. | |||||
| // nest2: an arbitrarily nested structure. | |||||
| // check_types: if `True` (default) types of sequences are checked as well, | |||||
| // including the keys of dictionaries. If set to `False`, for example a | |||||
| // list and a tuple of objects will look the same if they have the same | |||||
| // size. Note that namedtuples with identical name and fields are always | |||||
| // considered to have the same shallow structure. Two types will also be | |||||
| // considered the same if they are both list subtypes (which allows "list" | |||||
| // and "_ListWrapper" from checkpointable dependency tracking to compare | |||||
| // equal). | |||||
| // Raises: | |||||
| // ValueError: If the two structures do not have the same number of elements or | |||||
| // if the two structures are not nested in the same way. | |||||
| // TypeError: If the two structures differ in the type of sequence in any of | |||||
| // their substructures. Only possible if `check_types` is `True`. | |||||
| // """ | |||||
| // try: | |||||
| // _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types) | |||||
| // except (ValueError, TypeError) as e: | |||||
| // str1 = str(map_structure(lambda _: _DOT, nest1)) | |||||
| // str2 = str(map_structure(lambda _: _DOT, nest2)) | |||||
| // raise type(e)("%s\n" | |||||
| // "Entire first structure:\n%s\n" | |||||
| // "Entire second structure:\n%s" | |||||
| // % (str(e), str1, str2)) | |||||
| //def flatten_dict_items(dictionary): | |||||
| // """Returns a dictionary with flattened keys and values. | |||||
| // This function flattens the keys and values of a dictionary, which can be | |||||
| // arbitrarily nested structures, and returns the flattened version of such | |||||
| // structures: | |||||
| // ```python | |||||
| // example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} | |||||
| // result = {4: "a", 5: "b", 6: "c", 8: "d"} | |||||
| // flatten_dict_items(example_dictionary) == result | |||||
| // ``` | |||||
| // The input dictionary must satisfy two properties: | |||||
| // 1. Its keys and values should have the same exact nested structure. | |||||
| // 2. The set of all flattened keys of the dictionary must not contain repeated | |||||
| // keys. | |||||
| // Args: | |||||
| // dictionary: the dictionary to zip | |||||
| // Returns: | |||||
| // The zipped dictionary. | |||||
| // Raises: | |||||
| // TypeError: If the input is not a dictionary. | |||||
| // ValueError: If any key and value have not the same structure, or if keys are | |||||
| // not unique. | |||||
| // """ | |||||
| // if not isinstance(dictionary, (dict, _collections.Mapping)): | |||||
| // raise TypeError("input must be a dictionary") | |||||
| // flat_dictionary = {} | |||||
| // for i, v in _six.iteritems(dictionary): | |||||
| // if not is_sequence(i): | |||||
| // if i in flat_dictionary: | |||||
| // raise ValueError( | |||||
| // "Could not flatten dictionary: key %s is not unique." % i) | |||||
| // flat_dictionary[i] = v | |||||
| // else: | |||||
| // flat_i = flatten(i) | |||||
| // flat_v = flatten(v) | |||||
| // if len(flat_i) != len(flat_v): | |||||
| // raise ValueError( | |||||
| // "Could not flatten dictionary. Key had %d elements, but value had " | |||||
| // "%d elements. Key: %s, value: %s." | |||||
| // % (len(flat_i), len(flat_v), flat_i, flat_v)) | |||||
| // for new_i, new_v in zip(flat_i, flat_v): | |||||
| // if new_i in flat_dictionary: | |||||
| // raise ValueError( | |||||
| // "Could not flatten dictionary: key %s is not unique." | |||||
| // % (new_i)) | |||||
| // flat_dictionary[new_i] = new_v | |||||
| // return flat_dictionary | |||||
| /// <summary> | |||||
| /// Helper function for pack_sequence_as. | |||||
| /// </summary> | |||||
| /// <param name="structure">Substructure (list / tuple / dict) to mimic.</param> | |||||
| /// <param name="flat">Flattened values to output substructure for.</param> | |||||
| /// <param name="index">Index at which to start reading from flat.</param> | |||||
| /// <returns> | |||||
| /// The tuple(new_index, child), where: | |||||
| /// * new_index - the updated index into `flat` having processed `structure`. | |||||
| /// * packed - the subset of `flat` corresponding to `structure`, | |||||
| /// having started at `index`, and packed into the same nested | |||||
| /// format.</returns> | |||||
| private static (int new_index, List<object> child) _packed_nest_with_indices(object structure, List<object> flat, | |||||
| int index) | |||||
| { | |||||
| var packed = new List<object>(); | |||||
| foreach (var s in _yield_value(structure)) | |||||
| { | |||||
| if (is_sequence(s)) | |||||
| { | |||||
| var (new_index, child) = _packed_nest_with_indices(s, flat, index); | |||||
| packed.Add(_sequence_like(s, child)); | |||||
| index = new_index; | |||||
| } | |||||
| else | |||||
| { | |||||
| packed.Add(flat[index]); | |||||
| index += 1; | |||||
| } | |||||
| } | |||||
| return (index, packed); | |||||
| } | |||||
| private static int len(IEnumerable<object> x) => x.Count(); | |||||
| /// <summary> | |||||
| /// Returns a given flattened sequence packed into a given structure. | |||||
| /// If `structure` is a scalar, `flat_sequence` must be a single-element list; | |||||
| /// in this case the return value is `flat_sequence[0]`. | |||||
| /// | |||||
| /// If `structure` is or contains a dict instance, the keys will be sorted to | |||||
| /// pack the flat sequence in deterministic order. This is true also for | |||||
| /// `OrderedDict` instances: their sequence order is ignored, the sorting order of | |||||
| /// keys is used instead. The same convention is followed in `flatten`. | |||||
| /// This correctly repacks dicts and `OrderedDict`s after they have been | |||||
| /// flattened, and also allows flattening an `OrderedDict` and then repacking it | |||||
| /// back using a corresponding plain dict, or vice-versa. | |||||
| /// Dictionaries with non-sortable keys cannot be flattened. | |||||
| /// </summary> | |||||
| /// <param name="structure"> | |||||
| /// Nested structure, whose structure is given by nested lists, | |||||
| /// tuples, and dicts. Note: numpy arrays and strings are considered | |||||
| /// scalars. | |||||
| /// </param> | |||||
| /// <param name="flat_sequence"> flat sequence to pack.</param> | |||||
| /// <returns> `flat_sequence` converted to have the same recursive structure as | |||||
| /// `structure`. | |||||
| /// </returns> | |||||
| public static object pack_sequence_as(object structure, List<object> flat_sequence) | |||||
| { | |||||
| if (flat_sequence == null) | |||||
| throw new ArgumentException("flat_sequence must not be null"); | |||||
| // if not is_sequence(flat_sequence): | |||||
| // raise TypeError("flat_sequence must be a sequence") | |||||
| if (!is_sequence(structure)) | |||||
| { | |||||
| if (len(flat_sequence) != 1) | |||||
| throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat_sequence)} > 1"); | |||||
| return flat_sequence.FirstOrDefault(); | |||||
| } | |||||
| int final_index = 0; | |||||
| List<object> packed = null; | |||||
| try | |||||
| { | |||||
| (final_index, packed) = _packed_nest_with_indices(structure, flat_sequence, 0); | |||||
| if (final_index < len(flat_sequence)) | |||||
| throw new IndexOutOfRangeException($"Final index: { final_index} was smaller than len(flat_sequence): { len(flat_sequence) }"); | |||||
| } | |||||
| catch (IndexOutOfRangeException) | |||||
| { | |||||
| var flat_structure = flatten(structure); | |||||
| if (len(flat_structure) != len(flat_sequence)) | |||||
| { | |||||
| throw new ValueError("Could not pack sequence. Structure had %d elements, but " + | |||||
| $"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat_sequence)}"); | |||||
| } | |||||
| return _sequence_like(structure, packed); | |||||
| } | |||||
| return packed; | |||||
| } | |||||
| /// <summary> | |||||
| /// Applies `func` to each entry in `structure` and returns a new structure. | |||||
| /// | |||||
| /// Applies `func(x[0], x[1], ...)` where x[i] is an entry in | |||||
| /// `structure[i]`. All structures in `structure` must have the same arity, | |||||
| /// and the return value will contain the results in the same structure. | |||||
| /// </summary> | |||||
| /// <typeparam name="T"></typeparam> | |||||
| /// <typeparam name="U"></typeparam> | |||||
| /// <param name="func"> A callable that accepts as many arguments as there are structures.</param> | |||||
| /// <param name="structure">scalar, or tuple or list of constructed scalars and/or other | |||||
| /// tuples/lists, or scalars. Note: numpy arrays are considered as scalars.</param> | |||||
| /// <param name="check_types">If set to | |||||
| /// `True` (default) the types of iterables within the structures have to be | |||||
| /// same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError` | |||||
| /// exception). To allow this set this argument to `False`. | |||||
| /// Note that namedtuples with identical name and fields are always | |||||
| /// considered to have the same shallow structure.</param> | |||||
| /// <returns> | |||||
| /// A new structure with the same arity as `structure`, whose values correspond | |||||
| /// to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding | |||||
| /// location in `structure[i]`. If there are different sequence types and | |||||
| /// `check_types` is `False` the sequence types of the first structure will be | |||||
| /// used. | |||||
| /// </returns> | |||||
| public static IEnumerable<U> map_structure<T, U>(Func<T, U> func, IEnumerable<T> structure, bool check_types = false) | |||||
| { | |||||
| // for other in structure[1:]: | |||||
| // assert_same_structure(structure[0], other, check_types=check_types) | |||||
| // flat_structure = [flatten(s) for s in structure] | |||||
| // entries = zip(*flat_structure) | |||||
| // return pack_sequence_as( | |||||
| // structure[0], [func(*x) for x in entries]) | |||||
| return null; | |||||
| } | |||||
| //def map_structure_with_paths(func, *structure, **kwargs): | |||||
| // """Applies `func` to each entry in `structure` and returns a new structure. | |||||
| // Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in | |||||
| // `structure[i]` and `path` is the common path to x[i] in the structures. All | |||||
| // structures in `structure` must have the same arity, and the return value will | |||||
| // contain the results in the same structure. Special kwarg `check_types` | |||||
| // determines whether the types of iterables within the structure must be the | |||||
| // same-- see **kwargs definition below. | |||||
| // Args: | |||||
| // func: A callable with the signature func(path, *values, **kwargs) that is | |||||
| // evaluated on the leaves of the structure. | |||||
| // *structure: A variable number of compatible structures to process. | |||||
| // **kwargs: Optional kwargs to be passed through to func. Special kwarg | |||||
| // `check_types` is not passed to func, but instead determines whether the | |||||
| // types of iterables within the structures have to be same (e.g., | |||||
| // `map_structure(func, [1], (1,))` raises a `TypeError` exception). By | |||||
| // default, the types must match. To allow iteration over structures of | |||||
| // different types (but common arity), set this kwarg to `False`. | |||||
| // Returns: | |||||
| // A structure of the same form as the input structures whose leaves are the | |||||
| // result of evaluating func on corresponding leaves of the input structures. | |||||
| // Raises: | |||||
| // TypeError: If `func` is not callable or if the structures do not match | |||||
| // each other by depth tree. | |||||
| // TypeError: If `check_types` is not `False` and the two structures differ in | |||||
| // the type of sequence in any of their substructures. | |||||
| // ValueError: If no structures are provided. | |||||
| // """ | |||||
| // if not callable(func): | |||||
| // raise TypeError("func must be callable, got: %s" % func) | |||||
| // if not structure: | |||||
| // raise ValueError("Must provide at least one structure") | |||||
| // check_types = kwargs.pop("check_types", True) | |||||
| // for other in structure[1:]: | |||||
| // assert_same_structure(structure[0], other, check_types=check_types) | |||||
| //# First set paths_and_values to: | |||||
| //# [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]] | |||||
| // paths_and_values = [flatten_with_joined_string_paths(s) for s in structure] | |||||
| //# Now zip(*paths_and_values) would be: | |||||
| //# [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))] | |||||
| //# so grouped_by_path is set to: | |||||
| //# [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]] | |||||
| //# Note that p1i, ... pmi must all be equal since the structures are the same. | |||||
| // grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)] | |||||
| // return pack_sequence_as(structure[0], [ | |||||
| // func(paths[0], *values, **kwargs) for paths, values in grouped_by_path]) | |||||
| //def _yield_flat_up_to(shallow_tree, input_tree): | |||||
| // """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" | |||||
| // if is_sequence(shallow_tree): | |||||
| // for shallow_branch, input_branch in zip(_yield_value(shallow_tree), | |||||
| // _yield_value(input_tree)): | |||||
| // for input_leaf in _yield_flat_up_to(shallow_branch, input_branch): | |||||
| // yield input_leaf | |||||
| // else: | |||||
| // yield input_tree | |||||
| //def assert_shallow_structure(shallow_tree, input_tree, check_types=True): | |||||
| // """Asserts that `shallow_tree` is a shallow structure of `input_tree`. | |||||
| // That is, this function tests if the `input_tree` structure can be created from | |||||
| // the `shallow_tree` structure by replacing its leaf nodes with deeper | |||||
| // tree structures. | |||||
| // Examples: | |||||
| // The following code will raise an exception: | |||||
| // ```python | |||||
| // shallow_tree = ["a", "b"] | |||||
| // input_tree = ["c", ["d", "e"], "f"] | |||||
| // assert_shallow_structure(shallow_tree, input_tree) | |||||
| // ``` | |||||
| // The following code will not raise an exception: | |||||
| // ```python | |||||
| // shallow_tree = ["a", "b"] | |||||
| // input_tree = ["c", ["d", "e"]] | |||||
| // assert_shallow_structure(shallow_tree, input_tree) | |||||
| // ``` | |||||
| // Args: | |||||
| // shallow_tree: an arbitrarily nested structure. | |||||
| // input_tree: an arbitrarily nested structure. | |||||
| // check_types: if `True` (default) the sequence types of `shallow_tree` and | |||||
| // `input_tree` have to be the same. Note that even with check_types==True, | |||||
| // this function will consider two different namedtuple classes with the same | |||||
| // name and _fields attribute to be the same class. | |||||
| // Raises: | |||||
| // TypeError: If `shallow_tree` is a sequence but `input_tree` is not. | |||||
| // TypeError: If the sequence types of `shallow_tree` are different from | |||||
| // `input_tree`. Only raised if `check_types` is `True`. | |||||
| // ValueError: If the sequence lengths of `shallow_tree` are different from | |||||
| // `input_tree`. | |||||
| // """ | |||||
| // if is_sequence(shallow_tree): | |||||
| // if not is_sequence(input_tree): | |||||
| // raise TypeError( | |||||
| // "If shallow structure is a sequence, input must also be a sequence. " | |||||
| // "Input has type: %s." % type(input_tree)) | |||||
| // if check_types and not isinstance(input_tree, type(shallow_tree)): | |||||
| //# Duck-typing means that nest should be fine with two different | |||||
| //# namedtuples with identical name and fields. | |||||
| // shallow_is_namedtuple = _is_namedtuple(shallow_tree, False) | |||||
| // input_is_namedtuple = _is_namedtuple(input_tree, False) | |||||
| // if shallow_is_namedtuple and input_is_namedtuple: | |||||
| // if not _same_namedtuples(shallow_tree, input_tree): | |||||
| // raise TypeError( | |||||
| // "The two namedtuples don't have the same sequence type. Input " | |||||
| // "structure has type %s, while shallow structure has type %s." | |||||
| // % (type(input_tree), type(shallow_tree))) | |||||
| // elif not (isinstance(shallow_tree, _collections.Mapping) | |||||
| // and isinstance(input_tree, _collections.Mapping)): | |||||
| // raise TypeError( | |||||
| // "The two structures don't have the same sequence type. Input " | |||||
| // "structure has type %s, while shallow structure has type %s." | |||||
| // % (type(input_tree), type(shallow_tree))) | |||||
| // if len(input_tree) != len(shallow_tree): | |||||
| // raise ValueError( | |||||
| // "The two structures don't have the same sequence length. Input " | |||||
| // "structure has length %s, while shallow structure has length %s." | |||||
| // % (len(input_tree), len(shallow_tree))) | |||||
| // if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)): | |||||
| // if set(input_tree) != set(shallow_tree): | |||||
| // raise ValueError( | |||||
| // "The two structures don't have the same keys. Input " | |||||
| // "structure has keys %s, while shallow structure has keys %s." % | |||||
| // (list(_six.iterkeys(input_tree)), | |||||
| // list(_six.iterkeys(shallow_tree)))) | |||||
| // input_tree = list(sorted(_six.iteritems(input_tree))) | |||||
| // shallow_tree = list(sorted(_six.iteritems(shallow_tree))) | |||||
| // for shallow_branch, input_branch in zip(shallow_tree, input_tree): | |||||
| // assert_shallow_structure(shallow_branch, input_branch, | |||||
| // check_types=check_types) | |||||
| //def flatten_up_to(shallow_tree, input_tree): | |||||
| // """Flattens `input_tree` up to `shallow_tree`. | |||||
| // Any further depth in structure in `input_tree` is retained as elements in the | |||||
| // partially flatten output. | |||||
| // If `shallow_tree` and `input_tree` are not sequences, this returns a | |||||
| // single-element list: `[input_tree]`. | |||||
| // Use Case: | |||||
| // Sometimes we may wish to partially flatten a nested sequence, retaining some | |||||
| // of the nested structure. We achieve this by specifying a shallow structure, | |||||
| // `shallow_tree`, we wish to flatten up to. | |||||
| // The input, `input_tree`, can be thought of as having the same structure as | |||||
| // `shallow_tree`, but with leaf nodes that are themselves tree structures. | |||||
| // Examples: | |||||
| // ```python | |||||
| // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||||
| // shallow_tree = [[True, True], [False, True]] | |||||
| // flattened_input_tree = flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree) | |||||
| //# Output is: | |||||
| //# [[2, 2], [3, 3], [4, 9], [5, 5]] | |||||
| //# [True, True, False, True] | |||||
| // ``` | |||||
| // ```python | |||||
| // input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] | |||||
| // shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] | |||||
| // input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) | |||||
| // input_tree_flattened = flatten(input_tree) | |||||
| //# Output is: | |||||
| //# [('a', 1), ('b', 2), ('c', 3), ('d', 4)] | |||||
| //# ['a', 1, 'b', 2, 'c', 3, 'd', 4] | |||||
| // ``` | |||||
| // Non-Sequence Edge Cases: | |||||
| // ```python | |||||
| // flatten_up_to(0, 0) # Output: [0] | |||||
| // flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]] | |||||
| // flatten_up_to([0, 1, 2], 0) # Output: TypeError | |||||
| // flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2] | |||||
| // ``` | |||||
| // Args: | |||||
| // shallow_tree: a possibly pruned structure of input_tree. | |||||
| // input_tree: an arbitrarily nested structure or a scalar object. | |||||
| // Note, numpy arrays are considered scalars. | |||||
| // Returns: | |||||
| // A Python list, the partially flattened version of `input_tree` according to | |||||
| // the structure of `shallow_tree`. | |||||
| // Raises: | |||||
| // TypeError: If `shallow_tree` is a sequence but `input_tree` is not. | |||||
| // TypeError: If the sequence types of `shallow_tree` are different from | |||||
| // `input_tree`. | |||||
| // ValueError: If the sequence lengths of `shallow_tree` are different from | |||||
| // `input_tree`. | |||||
| // """ | |||||
| // assert_shallow_structure(shallow_tree, input_tree) | |||||
| // return list(_yield_flat_up_to(shallow_tree, input_tree)) | |||||
| //def map_structure_up_to(shallow_tree, func, *inputs): | |||||
| // """Applies a function or op to a number of partially flattened inputs. | |||||
| // The `inputs` are flattened up to `shallow_tree` before being mapped. | |||||
| // Use Case: | |||||
| // Sometimes we wish to apply a function to a partially flattened | |||||
| // sequence (for example when the function itself takes sequence inputs). We | |||||
| // achieve this by specifying a shallow structure, `shallow_tree` we wish to | |||||
| // flatten up to. | |||||
| // The `inputs`, can be thought of as having the same structure as | |||||
| // `shallow_tree`, but with leaf nodes that are themselves tree structures. | |||||
| // This function therefore will return something with the same base structure as | |||||
| // `shallow_tree`. | |||||
| // Examples: | |||||
| // ```python | |||||
| // ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||||
| // op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||||
| // inp_val = ab_tuple(a=2, b=3) | |||||
| // inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) | |||||
| // out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, | |||||
| // inp_val, inp_ops) | |||||
| //# Output is: ab_tuple(a=6, b=15) | |||||
| // ``` | |||||
| // ```python | |||||
| // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||||
| // name_list = ['evens', ['odds', 'primes']] | |||||
| // out = map_structure_up_to( | |||||
| // name_list, | |||||
| // lambda name, sec: "first_{}_{}".format(len(sec), name), | |||||
| // name_list, data_list) | |||||
| //# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']] | |||||
| // ``` | |||||
| // Args: | |||||
| // shallow_tree: a shallow tree, common to all the inputs. | |||||
| // func: callable which will be applied to each input individually. | |||||
| // *inputs: arbitrarily nested combination of objects that are compatible with | |||||
| // shallow_tree. The function `func` is applied to corresponding | |||||
| // partially flattened elements of each input, so the function must support | |||||
| // arity of `len(inputs)`. | |||||
| // Raises: | |||||
| // TypeError: If `shallow_tree` is a sequence but `input_tree` is not. | |||||
| // TypeError: If the sequence types of `shallow_tree` are different from | |||||
| // `input_tree`. | |||||
| // ValueError: If the sequence lengths of `shallow_tree` are different from | |||||
| // `input_tree`. | |||||
| // Returns: | |||||
| // result of repeatedly applying `func`, with same structure as | |||||
| // `shallow_tree`. | |||||
| // """ | |||||
| // if not inputs: | |||||
| // raise ValueError("Cannot map over no sequences") | |||||
| // for input_tree in inputs: | |||||
| // assert_shallow_structure(shallow_tree, input_tree) | |||||
| //# Flatten each input separately, apply the function to corresponding elements, | |||||
| //# then repack based on the structure of the first input. | |||||
| // all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree) | |||||
| // for input_tree in inputs] | |||||
| // results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] | |||||
| // return pack_sequence_as(structure=shallow_tree, flat_sequence=results) | |||||
| //def get_traverse_shallow_structure(traverse_fn, structure): | |||||
| // """Generates a shallow structure from a `traverse_fn` and `structure`. | |||||
| // `traverse_fn` must accept any possible subtree of `structure` and return | |||||
| // a depth=1 structure containing `True` or `False` values, describing which | |||||
| // of the top-level subtrees may be traversed. It may also | |||||
| // return scalar `True` or `False` "traversal is OK / not OK for all subtrees." | |||||
| // Examples are available in the unit tests (nest_test.py). | |||||
| // Args: | |||||
| // traverse_fn: Function taking a substructure and returning either a scalar | |||||
| // `bool` (whether to traverse that substructure or not) or a depth=1 | |||||
| // shallow structure of the same type, describing which parts of the | |||||
| // substructure to traverse. | |||||
| // structure: The structure to traverse. | |||||
| // Returns: | |||||
| // A shallow structure containing python bools, which can be passed to | |||||
| // `map_structure_up_to` and `flatten_up_to`. | |||||
| // Raises: | |||||
| // TypeError: if `traverse_fn` returns a sequence for a non-sequence input, | |||||
| // or a structure with depth higher than 1 for a sequence input, | |||||
| // or if any leaf values in the returned structure or scalar are not type | |||||
| // `bool`. | |||||
| // """ | |||||
| // to_traverse = traverse_fn(structure) | |||||
| // if not is_sequence(structure): | |||||
| // if not isinstance(to_traverse, bool): | |||||
| // raise TypeError("traverse_fn returned structure: %s for non-structure: %s" | |||||
| // % (to_traverse, structure)) | |||||
| // return to_traverse | |||||
| // level_traverse = [] | |||||
| // if isinstance(to_traverse, bool): | |||||
| // if not to_traverse: | |||||
| //# Do not traverse this substructure at all. Exit early. | |||||
| // return False | |||||
| // else: | |||||
| //# Traverse the entire substructure. | |||||
| // for branch in _yield_value(structure): | |||||
| // level_traverse.append( | |||||
| // get_traverse_shallow_structure(traverse_fn, branch)) | |||||
| // elif not is_sequence(to_traverse): | |||||
| // raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" | |||||
| // % (to_traverse, structure)) | |||||
| // else: | |||||
| //# Traverse some subset of this substructure. | |||||
| // assert_shallow_structure(to_traverse, structure) | |||||
| // for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)): | |||||
| // if not isinstance(t, bool): | |||||
| // raise TypeError( | |||||
| // "traverse_fn didn't return a depth=1 structure of bools. saw: %s " | |||||
| // " for structure: %s" % (to_traverse, structure)) | |||||
| // if t: | |||||
| // level_traverse.append( | |||||
| // get_traverse_shallow_structure(traverse_fn, branch)) | |||||
| // else: | |||||
| // level_traverse.append(False) | |||||
| // return _sequence_like(structure, level_traverse) | |||||
| //def yield_flat_paths(nest): | |||||
| // """Yields paths for some nested structure. | |||||
| // Paths are lists of objects which can be str-converted, which may include | |||||
| // integers or other types which are used as indices in a dict. | |||||
| // The flat list will be in the corresponding order as if you called | |||||
| // `snt.nest.flatten` on the structure. This is handy for naming Tensors such | |||||
| // the TF scope structure matches the tuple structure. | |||||
| // E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))` | |||||
| // ```shell | |||||
| // >>> nest.flatten(value) | |||||
| // [3, 23, 42] | |||||
| // >>> list(nest.yield_flat_paths(value)) | |||||
| // [('a',), ('b', 'c'), ('b', 'd')] | |||||
| // ``` | |||||
| // ```shell | |||||
| // >>> list(nest.yield_flat_paths({'a': [3]})) | |||||
| // [('a', 0)] | |||||
| // >>> list(nest.yield_flat_paths({'a': 3})) | |||||
| // [('a',)] | |||||
| // ``` | |||||
| // Args: | |||||
| // nest: the value to produce a flattened paths list for. | |||||
| // Yields: | |||||
| // Tuples containing index or key values which form the path to a specific | |||||
| // leaf value in the nested structure. | |||||
| // """ | |||||
| //# The _maybe_add_final_path_element function is used below in order to avoid | |||||
| //# adding trailing slashes when the sub-element recursed into is a leaf. | |||||
| // if isinstance(nest, (dict, _collections.Mapping)): | |||||
| // for key in _sorted(nest): | |||||
| // value = nest[key] | |||||
| // for sub_path in yield_flat_paths(value): | |||||
| // yield (key,) + sub_path | |||||
| // elif _is_namedtuple(nest): | |||||
| // for key in nest._fields: | |||||
| // value = getattr(nest, key) | |||||
| // for sub_path in yield_flat_paths(value): | |||||
| // yield (key,) + sub_path | |||||
| // elif isinstance(nest, _six.string_types): | |||||
| // yield () | |||||
| // elif isinstance(nest, _collections.Sequence): | |||||
| // for idx, value in enumerate(nest): | |||||
| // for sub_path in yield_flat_paths(value): | |||||
| // yield (idx,) + sub_path | |||||
| // else: | |||||
| // yield () | |||||
| //def flatten_with_joined_string_paths(structure, separator="/"): | |||||
| // """Returns a list of (string path, data element) tuples. | |||||
| // The order of tuples produced matches that of `nest.flatten`. This allows you | |||||
| // to flatten a nested structure while keeping information about where in the | |||||
| // structure each data element was located. See `nest.yield_flat_paths` | |||||
| // for more information. | |||||
| // Args: | |||||
| // structure: the nested structure to flatten. | |||||
| // separator: string to separate levels of hierarchy in the results, defaults | |||||
| // to '/'. | |||||
| // Returns: | |||||
| // A list of (string, data element) tuples. | |||||
| // """ | |||||
| // flat_paths = yield_flat_paths(structure) | |||||
| // def stringify_and_join(path_elements): | |||||
| // return separator.join(str(path_element) for path_element in path_elements) | |||||
| // flat_string_paths = [stringify_and_join(path) for path in flat_paths] | |||||
| // return list(zip(flat_string_paths, flatten(structure))) | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,852 @@ | |||||
| using System.Collections; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||||
| using Newtonsoft.Json.Linq; | |||||
| using Tensorflow; | |||||
| using Tensorflow.Util; | |||||
| namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||||
| { | |||||
| /// <summary> | |||||
| /// excerpt of tensorflow/python/framework/util/nest_test.py | |||||
| /// </summary> | |||||
| [TestClass] | |||||
| public class NestTest : PythonTest | |||||
| { | |||||
| public class PointXY | |||||
| { | |||||
| public double x; | |||||
| public double y; | |||||
| } | |||||
| // if attr: | |||||
| // class BadAttr(object): | |||||
| // """Class that has a non-iterable __attrs_attrs__.""" | |||||
| // __attrs_attrs__ = None | |||||
| // @attr.s | |||||
| // class SampleAttr(object): | |||||
| // field1 = attr.ib() | |||||
| // field2 = attr.ib() | |||||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| // def testAttrsFlattenAndPack(self) : | |||||
| // if attr is None: | |||||
| // self.skipTest("attr module is unavailable.") | |||||
| // field_values = [1, 2] | |||||
| // sample_attr = NestTest.SampleAttr(* field_values) | |||||
| // self.assertFalse(nest._is_attrs(field_values)) | |||||
| // self.assertTrue(nest._is_attrs(sample_attr)) | |||||
| // flat = nest.flatten(sample_attr) | |||||
| // self.assertEqual(field_values, flat) | |||||
| // restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) | |||||
| // self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) | |||||
| // self.assertEqual(restructured_from_flat, sample_attr) | |||||
| //# Check that flatten fails if attributes are not iterable | |||||
| // with self.assertRaisesRegexp(TypeError, "object is not iterable"): | |||||
| // flat = nest.flatten(NestTest.BadAttr()) | |||||
| [TestMethod] | |||||
| public void testFlattenAndPack() | |||||
| { | |||||
| object structure = new object[] {new object[] {3, 4}, 5, new object[] {6, 7, new object[] {9, 10}, 8}}; | |||||
| var flat = new List<object> {"a", "b", "c", "d", "e", "f", "g", "h"}; | |||||
| self.assertEqual(nest.flatten(structure), new[] {3, 4, 5, 6, 7, 9, 10, 8}); | |||||
| self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(), | |||||
| JArray.FromObject(new object[] {new object[] {"a", "b"}, "c", new object[] {"d", "e", new object[] {"f", "g"}, "h"}}).ToString()); | |||||
| structure = new object[] { new Hashtable {["x"] = 4, ["y"] = 2}, new object[] { new object[] { new Hashtable { ["x"] = 1,["y"] = 0}, }, }}; | |||||
| flat = new List<object> { 4, 2, 1, 0}; | |||||
| self.assertEqual(nest.flatten(structure), flat); | |||||
| // restructured_from_flat = nest.pack_sequence_as(structure, flat) | |||||
| // self.assertEqual(restructured_from_flat, structure) | |||||
| // self.assertEqual(restructured_from_flat[0].x, 4) | |||||
| // self.assertEqual(restructured_from_flat[0].y, 2) | |||||
| // self.assertEqual(restructured_from_flat[1][0][0].x, 1) | |||||
| // self.assertEqual(restructured_from_flat[1][0][0].y, 0) | |||||
| // self.assertEqual([5], nest.flatten(5)) | |||||
| // self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) | |||||
| // self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) | |||||
| // self.assertEqual( | |||||
| // np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) | |||||
| // with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): | |||||
| // nest.pack_sequence_as("scalar", [4, 5]) | |||||
| // with self.assertRaisesRegexp(TypeError, "flat_sequence"): | |||||
| // nest.pack_sequence_as([4, 5], "bad_sequence") | |||||
| // with self.assertRaises(ValueError): | |||||
| // nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) | |||||
| } | |||||
| // @parameterized.parameters({"mapping_type": collections.OrderedDict | |||||
| // }, | |||||
| // {"mapping_type": _CustomMapping | |||||
| //}) | |||||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| // def testFlattenDictOrder(self, mapping_type) : | |||||
| // """`flatten` orders dicts by key, including OrderedDicts.""" | |||||
| // ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) | |||||
| // plain = {"d": 3, "b": 1, "a": 0, "c": 2} | |||||
| // ordered_flat = nest.flatten(ordered) | |||||
| // plain_flat = nest.flatten(plain) | |||||
| // self.assertEqual([0, 1, 2, 3], ordered_flat) | |||||
| // self.assertEqual([0, 1, 2, 3], plain_flat) | |||||
| // @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||||
| // {"mapping_type": _CustomMapping}) | |||||
| // def testPackDictOrder(self, mapping_type): | |||||
| // """Packing orders dicts by key, including OrderedDicts.""" | |||||
| // custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) | |||||
| // plain = {"d": 0, "b": 0, "a": 0, "c": 0} | |||||
| // seq = [0, 1, 2, 3] | |||||
| //custom_reconstruction = nest.pack_sequence_as(custom, seq) | |||||
| //plain_reconstruction = nest.pack_sequence_as(plain, seq) | |||||
| // self.assertIsInstance(custom_reconstruction, mapping_type) | |||||
| // self.assertIsInstance(plain_reconstruction, dict) | |||||
| // self.assertEqual( | |||||
| // mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), | |||||
| // custom_reconstruction) | |||||
| // self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) | |||||
| // Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name | |||||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| // def testFlattenAndPack_withDicts(self) : | |||||
| // # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. | |||||
| // mess = [ | |||||
| // "z", | |||||
| // NestTest.Abc(3, 4), { | |||||
| // "d": _CustomMapping({ | |||||
| // 41: 4 | |||||
| // }), | |||||
| // "c": [ | |||||
| // 1, | |||||
| // collections.OrderedDict([ | |||||
| // ("b", 3), | |||||
| // ("a", 2), | |||||
| // ]), | |||||
| // ], | |||||
| // "b": 5 | |||||
| // }, 17 | |||||
| // ] | |||||
| // flattened = nest.flatten(mess) | |||||
| // self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) | |||||
| // structure_of_mess = [ | |||||
| // 14, | |||||
| // NestTest.Abc("a", True), | |||||
| // { | |||||
| // "d": _CustomMapping({ | |||||
| // 41: 42 | |||||
| // }), | |||||
| // "c": [ | |||||
| // 0, | |||||
| // collections.OrderedDict([ | |||||
| // ("b", 9), | |||||
| // ("a", 8), | |||||
| // ]), | |||||
| // ], | |||||
| // "b": 3 | |||||
| // }, | |||||
| // "hi everybody", | |||||
| // ] | |||||
| // unflattened = nest.pack_sequence_as(structure_of_mess, flattened) | |||||
| // self.assertEqual(unflattened, mess) | |||||
| // # Check also that the OrderedDict was created, with the correct key order. | |||||
| //unflattened_ordered_dict = unflattened[2]["c"][1] | |||||
| // self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) | |||||
| // self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) | |||||
| // unflattened_custom_mapping = unflattened[2]["d"] | |||||
| // self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) | |||||
| // self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) | |||||
| // def testFlatten_numpyIsNotFlattened(self): | |||||
| // structure = np.array([1, 2, 3]) | |||||
| // flattened = nest.flatten(structure) | |||||
| // self.assertEqual(len(flattened), 1) | |||||
| // def testFlatten_stringIsNotFlattened(self): | |||||
| // structure = "lots of letters" | |||||
| // flattened = nest.flatten(structure) | |||||
| // self.assertEqual(len(flattened), 1) | |||||
| // unflattened = nest.pack_sequence_as("goodbye", flattened) | |||||
| // self.assertEqual(structure, unflattened) | |||||
| // def testPackSequenceAs_notIterableError(self) : | |||||
| // with self.assertRaisesRegexp(TypeError, | |||||
| // "flat_sequence must be a sequence"): | |||||
| // nest.pack_sequence_as("hi", "bye") | |||||
| // def testPackSequenceAs_wrongLengthsError(self): | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // "Structure had 2 elements, but flat_sequence had 3 elements."): | |||||
| // nest.pack_sequence_as(["hello", "world"], | |||||
| // ["and", "goodbye", "again"]) | |||||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| // def testIsSequence(self): | |||||
| // self.assertFalse(nest.is_sequence("1234")) | |||||
| // self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) | |||||
| // self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) | |||||
| // self.assertTrue(nest.is_sequence([])) | |||||
| // self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) | |||||
| // self.assertFalse(nest.is_sequence(set([1, 2]))) | |||||
| // ones = array_ops.ones([2, 3]) | |||||
| // self.assertFalse(nest.is_sequence(ones)) | |||||
| // self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) | |||||
| // self.assertFalse(nest.is_sequence(np.ones((4, 5)))) | |||||
| // @parameterized.parameters({"mapping_type": _CustomMapping}, | |||||
| // {"mapping_type": dict}) | |||||
| // def testFlattenDictItems(self, mapping_type): | |||||
| // dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))}) | |||||
| // flat = {4: "a", 5: "b", 6: "c", 8: "d"} | |||||
| // self.assertEqual(nest.flatten_dict_items(dictionary), flat) | |||||
| // with self.assertRaises(TypeError): | |||||
| // nest.flatten_dict_items(4) | |||||
| // bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))}) | |||||
| // with self.assertRaisesRegexp(ValueError, "not unique"): | |||||
| // nest.flatten_dict_items(bad_dictionary) | |||||
| // another_bad_dictionary = mapping_type({ | |||||
| // (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) | |||||
| // }) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): | |||||
| // nest.flatten_dict_items(another_bad_dictionary) | |||||
| //# pylint does not correctly recognize these as class names and | |||||
| //# suggests to use variable style under_score naming. | |||||
| //# pylint: disable=invalid-name | |||||
| // Named0ab = collections.namedtuple("named_0", ("a", "b")) | |||||
| // Named1ab = collections.namedtuple("named_1", ("a", "b")) | |||||
| // SameNameab = collections.namedtuple("same_name", ("a", "b")) | |||||
| // SameNameab2 = collections.namedtuple("same_name", ("a", "b")) | |||||
| // SameNamexy = collections.namedtuple("same_name", ("x", "y")) | |||||
| // SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) | |||||
| // SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) | |||||
| // NotSameName = collections.namedtuple("not_same_name", ("a", "b")) | |||||
| // # pylint: enable=invalid-name | |||||
| // class SameNamedType1(SameNameab): | |||||
| // pass | |||||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| // def testAssertSameStructure(self): | |||||
| // structure1 = (((1, 2), 3), 4, (5, 6)) | |||||
| // structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||||
| // structure_different_num_elements = ("spam", "eggs") | |||||
| // structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) | |||||
| // nest.assert_same_structure(structure1, structure2) | |||||
| // nest.assert_same_structure("abc", 1.0) | |||||
| // nest.assert_same_structure("abc", np.array([0, 1])) | |||||
| // nest.assert_same_structure("abc", constant_op.constant([0, 1])) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // ("The two structures don't have the same nested structure\\.\n\n" | |||||
| // "First structure:.*?\n\n" | |||||
| // "Second structure:.*\n\n" | |||||
| // "More specifically: Substructure " | |||||
| // r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' | |||||
| // 'substructure "type=str str=spam" is not\n' | |||||
| // "Entire first structure:\n" | |||||
| // r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" | |||||
| // "Entire second structure:\n" | |||||
| // r"\(\., \.\)")): | |||||
| // nest.assert_same_structure(structure1, structure_different_num_elements) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // ("The two structures don't have the same nested structure\\.\n\n" | |||||
| // "First structure:.*?\n\n" | |||||
| // "Second structure:.*\n\n" | |||||
| // r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||||
| // r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' | |||||
| // "is not")): | |||||
| // nest.assert_same_structure([0, 1], np.array([0, 1])) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // ("The two structures don't have the same nested structure\\.\n\n" | |||||
| // "First structure:.*?\n\n" | |||||
| // "Second structure:.*\n\n" | |||||
| // r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||||
| // 'is a sequence, while substructure "type=int str=0" ' | |||||
| // "is not")): | |||||
| // nest.assert_same_structure(0, [0, 1]) | |||||
| // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // ("don't have the same nested structure\\.\n\n" | |||||
| // "First structure: .*?\n\nSecond structure: ")): | |||||
| // nest.assert_same_structure(structure1, structure_different_nesting) | |||||
| // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), | |||||
| // NestTest.Named0ab("a", "b")) | |||||
| // nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||||
| // NestTest.Named0ab("a", "b")) | |||||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| // NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // ("don't have the same nested structure\\.\n\n" | |||||
| // "First structure: .*?\n\nSecond structure: ")): | |||||
| // nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||||
| // NestTest.Named0ab([3], 4)) | |||||
| // with self.assertRaisesRegexp( | |||||
| // ValueError, | |||||
| // ("don't have the same nested structure\\.\n\n" | |||||
| // "First structure: .*?\n\nSecond structure: ")): | |||||
| // nest.assert_same_structure([[3], 4], [3, [4]]) | |||||
| // structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||||
| // with self.assertRaisesRegexp(TypeError, | |||||
| // "don't have the same sequence type"): | |||||
| // nest.assert_same_structure(structure1, structure1_list) | |||||
| // nest.assert_same_structure(structure1, structure2, check_types= False) | |||||
| // nest.assert_same_structure(structure1, structure1_list, check_types=False) | |||||
| // with self.assertRaisesRegexp(ValueError, | |||||
| // "don't have the same set of keys"): | |||||
| // nest.assert_same_structure({"a": 1}, {"b": 1}) | |||||
| // nest.assert_same_structure(NestTest.SameNameab(0, 1), | |||||
| // NestTest.SameNameab2(2, 3)) | |||||
| // # This assertion is expected to pass: two namedtuples with the same | |||||
| // # name and field names are considered to be identical. | |||||
| // nest.assert_same_structure( | |||||
| // NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), | |||||
| // NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) | |||||
| // expected_message = "The two structures don't have the same.*" | |||||
| // with self.assertRaisesRegexp(ValueError, expected_message): | |||||
| // nest.assert_same_structure( | |||||
| // NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), | |||||
| // NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) | |||||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| // NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) | |||||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| // NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) | |||||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| // NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) | |||||
| // EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name | |||||
| // def testHeterogeneousComparison(self): | |||||
| // nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3)) | |||||
| // nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) | |||||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| // def testMapStructure(self) : | |||||
| // structure1 = (((1, 2), 3), 4, (5, 6)) | |||||
| // structure2 = (((7, 8), 9), 10, (11, 12)) | |||||
| // structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) | |||||
| // nest.assert_same_structure(structure1, structure1_plus1) | |||||
| // self.assertAllEqual( | |||||
| // [2, 3, 4, 5, 6, 7], | |||||
| // nest.flatten(structure1_plus1)) | |||||
| // structure1_plus_structure2 = nest.map_structure( | |||||
| // lambda x, y: x + y, structure1, structure2) | |||||
| // self.assertEqual( | |||||
| // (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), | |||||
| // structure1_plus_structure2) | |||||
| // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) | |||||
| // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) | |||||
| // # Empty structures | |||||
| // self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) | |||||
| // self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) | |||||
| // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) | |||||
| // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, | |||||
| // NestTest.EmptyNT())) | |||||
| // # This is checking actual equality of types, empty list != empty tuple | |||||
| // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) | |||||
| // with self.assertRaisesRegexp(TypeError, "callable"): | |||||
| // nest.map_structure("bad", structure1_plus1) | |||||
| // with self.assertRaisesRegexp(ValueError, "at least one structure"): | |||||
| // nest.map_structure(lambda x: x) | |||||
| // with self.assertRaisesRegexp(ValueError, "same number of elements"): | |||||
| // nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) | |||||
| // with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
| // nest.map_structure(lambda x, y: None, 3, (3,)) | |||||
| // with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||||
| // nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) | |||||
| // with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
| // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) | |||||
| // structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||||
| // with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||||
| // nest.map_structure(lambda x, y: None, structure1, structure1_list) | |||||
| // nest.map_structure(lambda x, y: None, structure1, structure1_list, | |||||
| // check_types=False) | |||||
| // with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
| // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), | |||||
| // check_types=False) | |||||
| // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||||
| // nest.map_structure(lambda x: None, structure1, foo="a") | |||||
| // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||||
| // nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") | |||||
| // ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name | |||||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| // def testMapStructureWithStrings(self) : | |||||
| // inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) | |||||
| // inp_b = NestTest.ABTuple(a=2, b=(1, 3)) | |||||
| // out = nest.map_structure(lambda string, repeats: string* repeats, | |||||
| // inp_a, | |||||
| // inp_b) | |||||
| // self.assertEqual("foofoo", out.a) | |||||
| // self.assertEqual("bar", out.b[0]) | |||||
| // self.assertEqual("bazbazbaz", out.b[1]) | |||||
| // nt = NestTest.ABTuple(a=("something", "something_else"), | |||||
| // b="yet another thing") | |||||
| // rev_nt = nest.map_structure(lambda x: x[::- 1], nt) | |||||
| // # Check the output is the correct structure, and all strings are reversed. | |||||
| // nest.assert_same_structure(nt, rev_nt) | |||||
| // self.assertEqual(nt.a[0][::- 1], rev_nt.a[0]) | |||||
| // self.assertEqual(nt.a[1][::- 1], rev_nt.a[1]) | |||||
| // self.assertEqual(nt.b[::- 1], rev_nt.b) | |||||
| // @test_util.run_deprecated_v1 | |||||
| // def testMapStructureOverPlaceholders(self) : | |||||
| // inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||||
| // array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||||
| // inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||||
| // array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||||
| // output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) | |||||
| // nest.assert_same_structure(output, inp_a) | |||||
| // self.assertShapeEqual(np.zeros((3, 4)), output[0]) | |||||
| // self.assertShapeEqual(np.zeros((3, 7)), output[1]) | |||||
| // feed_dict = { | |||||
| // inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), | |||||
| // inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) | |||||
| // } | |||||
| // with self.cached_session() as sess: | |||||
| // output_np = sess.run(output, feed_dict=feed_dict) | |||||
| // self.assertAllClose(output_np[0], | |||||
| // feed_dict[inp_a][0] + feed_dict[inp_b][0]) | |||||
| // self.assertAllClose(output_np[1], | |||||
| // feed_dict[inp_a][1] + feed_dict[inp_b][1]) | |||||
| // def testAssertShallowStructure(self): | |||||
| // inp_ab = ["a", "b"] | |||||
| //inp_abc = ["a", "b", "c"] | |||||
| //expected_message = ( | |||||
| // "The two structures don't have the same sequence length. Input " | |||||
| // "structure has length 2, while shallow structure has length 3.") | |||||
| // with self.assertRaisesRegexp(ValueError, expected_message): | |||||
| // nest.assert_shallow_structure(inp_abc, inp_ab) | |||||
| // inp_ab1 = [(1, 1), (2, 2)] | |||||
| // inp_ab2 = [[1, 1], [2, 2]] | |||||
| // expected_message = ( | |||||
| // "The two structures don't have the same sequence type. Input structure " | |||||
| // "has type <(type|class) 'tuple'>, while shallow structure has type " | |||||
| // "<(type|class) 'list'>.") | |||||
| // with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| // nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||||
| // nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False) | |||||
| // inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} | |||||
| // inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} | |||||
| // expected_message = ( | |||||
| // r"The two structures don't have the same keys. Input " | |||||
| // r"structure has keys \['c'\], while shallow structure has " | |||||
| // r"keys \['d'\].") | |||||
| // with self.assertRaisesRegexp(ValueError, expected_message): | |||||
| // nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||||
| // inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) | |||||
| // inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) | |||||
| // nest.assert_shallow_structure(inp_ab, inp_ba) | |||||
| // # This assertion is expected to pass: two namedtuples with the same | |||||
| //# name and field names are considered to be identical. | |||||
| //inp_shallow = NestTest.SameNameab(1, 2) | |||||
| // inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) | |||||
| // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) | |||||
| // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) | |||||
| // def testFlattenUpTo(self): | |||||
| // # Shallow tree ends at scalar. | |||||
| // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||||
| // shallow_tree = [[True, True], [False, True]] | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) | |||||
| // self.assertEqual(flattened_shallow_tree, [True, True, False, True]) | |||||
| //# Shallow tree ends at string. | |||||
| // input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] | |||||
| // shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] | |||||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| // input_tree) | |||||
| // input_tree_flattened = nest.flatten(input_tree) | |||||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| // [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) | |||||
| // self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) | |||||
| // # Make sure dicts are correctly flattened, yielding values, not keys. | |||||
| //input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} | |||||
| // shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} | |||||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| // input_tree) | |||||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| // [1, { "c": 2}, 3, (4, 5)]) | |||||
| // # Namedtuples. | |||||
| // ab_tuple = NestTest.ABTuple | |||||
| // input_tree = ab_tuple(a =[0, 1], b = 2) | |||||
| // shallow_tree = ab_tuple(a= 0, b= 1) | |||||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| // input_tree) | |||||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| // [[0, 1], 2]) | |||||
| // # Nested dicts, OrderedDicts and namedtuples. | |||||
| // input_tree = collections.OrderedDict( | |||||
| // [("a", ab_tuple(a =[0, {"b": 1}], b=2)), | |||||
| // ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) | |||||
| // shallow_tree = input_tree | |||||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| // input_tree) | |||||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) | |||||
| // shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) | |||||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| // input_tree) | |||||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| // [ab_tuple(a =[0, { "b": 1}], b=2), | |||||
| // 3, | |||||
| // collections.OrderedDict([("f", 4)])]) | |||||
| // shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) | |||||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| // input_tree) | |||||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| // [ab_tuple(a =[0, {"b": 1}], b=2), | |||||
| // {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) | |||||
| // ## Shallow non-list edge-case. | |||||
| // # Using iterable elements. | |||||
| // input_tree = ["input_tree"] | |||||
| //shallow_tree = "shallow_tree" | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| // input_tree = ["input_tree_0", "input_tree_1"] | |||||
| //shallow_tree = "shallow_tree" | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| // # Using non-iterable elements. | |||||
| //input_tree = [0] | |||||
| //shallow_tree = 9 | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| // input_tree = [0, 1] | |||||
| //shallow_tree = 9 | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| // ## Both non-list edge-case. | |||||
| //# Using iterable elements. | |||||
| //input_tree = "input_tree" | |||||
| // shallow_tree = "shallow_tree" | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| // # Using non-iterable elements. | |||||
| //input_tree = 0 | |||||
| // shallow_tree = 0 | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| // ## Input non-list edge-case. | |||||
| //# Using iterable elements. | |||||
| //input_tree = "input_tree" | |||||
| // shallow_tree = ["shallow_tree"] | |||||
| //expected_message = ("If shallow structure is a sequence, input must also " | |||||
| // "be a sequence. Input has type: <(type|class) 'str'>.") | |||||
| // with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| // input_tree = "input_tree" | |||||
| // shallow_tree = ["shallow_tree_9", "shallow_tree_8"] | |||||
| //with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| //# Using non-iterable elements. | |||||
| // input_tree = 0 | |||||
| // shallow_tree = [9] | |||||
| //expected_message = ("If shallow structure is a sequence, input must also " | |||||
| // "be a sequence. Input has type: <(type|class) 'int'>.") | |||||
| // with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| // input_tree = 0 | |||||
| // shallow_tree = [9, 8] | |||||
| //with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| // def testMapStructureUpTo(self) : | |||||
| // # Named tuples. | |||||
| // ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||||
| // op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||||
| // inp_val = ab_tuple(a= 2, b= 3) | |||||
| // inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3)) | |||||
| // out = nest.map_structure_up_to( | |||||
| // inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) | |||||
| // self.assertEqual(out.a, 6) | |||||
| // self.assertEqual(out.b, 15) | |||||
| // # Lists. | |||||
| // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||||
| // name_list = ["evens", ["odds", "primes"]] | |||||
| // out = nest.map_structure_up_to( | |||||
| // name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), | |||||
| // name_list, data_list) | |||||
| // self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) | |||||
| // # Dicts. | |||||
| // inp_val = dict(a= 2, b= 3) | |||||
| // inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) | |||||
| // out = nest.map_structure_up_to( | |||||
| // inp_val, | |||||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| // self.assertEqual(out["a"], 6) | |||||
| // self.assertEqual(out["b"], 15) | |||||
| // # Non-equal dicts. | |||||
| // inp_val = dict(a= 2, b= 3) | |||||
| // inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) | |||||
| // with self.assertRaisesRegexp(ValueError, "same keys"): | |||||
| // nest.map_structure_up_to( | |||||
| // inp_val, | |||||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| // # Dict+custom mapping. | |||||
| // inp_val = dict(a= 2, b= 3) | |||||
| // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) | |||||
| // out = nest.map_structure_up_to( | |||||
| // inp_val, | |||||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| // self.assertEqual(out["a"], 6) | |||||
| // self.assertEqual(out["b"], 15) | |||||
| // # Non-equal dict/mapping. | |||||
| // inp_val = dict(a= 2, b= 3) | |||||
| // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) | |||||
| // with self.assertRaisesRegexp(ValueError, "same keys"): | |||||
| // nest.map_structure_up_to( | |||||
| // inp_val, | |||||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| // def testGetTraverseShallowStructure(self): | |||||
| // scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] | |||||
| // scalar_traverse_r = nest.get_traverse_shallow_structure( | |||||
| // lambda s: not isinstance(s, tuple), | |||||
| // scalar_traverse_input) | |||||
| // self.assertEqual(scalar_traverse_r, | |||||
| // [True, True, False, [True, True], {"a": False}, []]) | |||||
| // nest.assert_shallow_structure(scalar_traverse_r, | |||||
| // scalar_traverse_input) | |||||
| // structure_traverse_input = [(1, [2]), ([1], 2)] | |||||
| // structure_traverse_r = nest.get_traverse_shallow_structure( | |||||
| // lambda s: (True, False) if isinstance(s, tuple) else True, | |||||
| // structure_traverse_input) | |||||
| // self.assertEqual(structure_traverse_r, | |||||
| // [(True, False), ([True], False)]) | |||||
| // nest.assert_shallow_structure(structure_traverse_r, | |||||
| // structure_traverse_input) | |||||
| // with self.assertRaisesRegexp(TypeError, "returned structure"): | |||||
| // nest.get_traverse_shallow_structure(lambda _: [True], 0) | |||||
| // with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): | |||||
| // nest.get_traverse_shallow_structure(lambda _: 1, [1]) | |||||
| // with self.assertRaisesRegexp( | |||||
| // TypeError, "didn't return a depth=1 structure of bools"): | |||||
| // nest.get_traverse_shallow_structure(lambda _: [1], [1]) | |||||
| // def testYieldFlatStringPaths(self): | |||||
| // for inputs_expected in ({"inputs": [], "expected": []}, | |||||
| // {"inputs": 3, "expected": [()]}, | |||||
| // {"inputs": [3], "expected": [(0,)]}, | |||||
| // {"inputs": {"a": 3}, "expected": [("a",)]}, | |||||
| // {"inputs": {"a": {"b": 4}}, | |||||
| // "expected": [("a", "b")]}, | |||||
| // {"inputs": [{"a": 2}], "expected": [(0, "a")]}, | |||||
| // {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, | |||||
| // {"inputs": [{"a": [(23, 42)]}], | |||||
| // "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, | |||||
| // {"inputs": [{"a": ([23], 42)}], | |||||
| // "expected": [(0, "a", 0, 0), (0, "a", 1)]}, | |||||
| // {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, | |||||
| // "expected": [("a", "a"), ("c", 0, 0, 0)]}, | |||||
| // {"inputs": {"0": [{"1": 23}]}, | |||||
| // "expected": [("0", 0, "1")]}): | |||||
| // inputs = inputs_expected["inputs"] | |||||
| // expected = inputs_expected["expected"] | |||||
| // self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) | |||||
| // def testFlattenWithStringPaths(self): | |||||
| // for inputs_expected in ( | |||||
| // {"inputs": [], "expected": []}, | |||||
| // {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, | |||||
| // {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): | |||||
| // inputs = inputs_expected["inputs"] | |||||
| // expected = inputs_expected["expected"] | |||||
| // self.assertEqual( | |||||
| // nest.flatten_with_joined_string_paths(inputs, separator="/"), | |||||
| // expected) | |||||
| // # Need a separate test for namedtuple as we can't declare tuple definitions | |||||
| // # in the @parameterized arguments. | |||||
| // def testFlattenNamedTuple(self): | |||||
| // # pylint: disable=invalid-name | |||||
| // Foo = collections.namedtuple("Foo", ["a", "b"]) | |||||
| // Bar = collections.namedtuple("Bar", ["c", "d"]) | |||||
| // # pylint: enable=invalid-name | |||||
| // test_cases = [ | |||||
| // (Foo(a = 3, b = Bar(c = 23, d = 42)), | |||||
| // [("a", 3), ("b/c", 23), ("b/d", 42)]), | |||||
| // (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")), | |||||
| // [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), | |||||
| // (Bar(c = 42, d = 43), | |||||
| // [("c", 42), ("d", 43)]), | |||||
| // (Bar(c =[42], d = 43), | |||||
| // [("c/0", 42), ("d", 43)]), | |||||
| // ] | |||||
| // for inputs, expected in test_cases: | |||||
| // self.assertEqual( | |||||
| // list(nest.flatten_with_joined_string_paths(inputs)), expected) | |||||
| // @parameterized.named_parameters( | |||||
| // ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), | |||||
| // ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, | |||||
| // {"a": ("a", 4), "b": ("b", 6)}), | |||||
| // ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), | |||||
| // ("nested", | |||||
| // {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, | |||||
| // {"a": [("a/0", 10), ("a/1", 12)], | |||||
| // "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) | |||||
| // def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): | |||||
| // def format_sum(path, * values): | |||||
| // return (path, sum(values)) | |||||
| // result = nest.map_structure_with_paths(format_sum, s1, s2, | |||||
| // check_types=check_types) | |||||
| // self.assertEqual(expected, result) | |||||
| // @parameterized.named_parameters( | |||||
| // ("tuples", (1, 2), (3, 4, 5), ValueError), | |||||
| // ("dicts", {"a": 1}, {"b": 2}, ValueError), | |||||
| // ("mixed", (1, 2), [3, 4], TypeError), | |||||
| // ("nested", | |||||
| // {"a": [2, 3], "b": [1, 3]}, | |||||
| // {"b": [5, 6, 7], "a": [8, 9]}, | |||||
| // ValueError | |||||
| // )) | |||||
| // def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): | |||||
| // with self.assertRaises(error_type): | |||||
| // nest.map_structure_with_paths(lambda path, * s: 0, s1, s2) | |||||
| //class NestBenchmark(test.Benchmark): | |||||
| // def run_and_report(self, s1, s2, name): | |||||
| // burn_iter, test_iter = 100, 30000 | |||||
| // for _ in xrange(burn_iter) : | |||||
| // nest.assert_same_structure(s1, s2) | |||||
| // t0 = time.time() | |||||
| // for _ in xrange(test_iter) : | |||||
| // nest.assert_same_structure(s1, s2) | |||||
| // t1 = time.time() | |||||
| // self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, | |||||
| // name=name) | |||||
| // def benchmark_assert_structure(self): | |||||
| // s1 = (((1, 2), 3), 4, (5, 6)) | |||||
| // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||||
| // self.run_and_report(s1, s2, "assert_same_structure_6_elem") | |||||
| // s1 = (((1, 2), 3), 4, (5, 6)) * 10 | |||||
| // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 | |||||
| // self.run_and_report(s1, s2, "assert_same_structure_60_elem") | |||||
| //if __name__ == "__main__": | |||||
| // test.main() | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,883 @@ | |||||
| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| # ============================================================================== | |||||
| """Tests for utilities working with arbitrarily nested structures.""" | |||||
| from __future__ import absolute_import | |||||
| from __future__ import division | |||||
| from __future__ import print_function | |||||
| import collections | |||||
| import time | |||||
| from absl.testing import parameterized | |||||
| import numpy as np | |||||
| from six.moves import xrange # pylint: disable=redefined-builtin | |||||
| from tensorflow.python.framework import constant_op | |||||
| from tensorflow.python.framework import dtypes | |||||
| from tensorflow.python.framework import test_util | |||||
| from tensorflow.python.ops import array_ops | |||||
| from tensorflow.python.ops import math_ops | |||||
| from tensorflow.python.platform import test | |||||
| from tensorflow.python.util import nest | |||||
| try: | |||||
| import attr # pylint:disable=g-import-not-at-top | |||||
| except ImportError: | |||||
| attr = None | |||||
| class _CustomMapping(collections.Mapping): | |||||
| def __init__(self, *args, **kwargs): | |||||
| self._wrapped = dict(*args, **kwargs) | |||||
| def __getitem__(self, key): | |||||
| return self._wrapped[key] | |||||
| def __iter__(self): | |||||
| return iter(self._wrapped) | |||||
| def __len__(self): | |||||
| return len(self._wrapped) | |||||
| class NestTest(parameterized.TestCase, test.TestCase): | |||||
| PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name | |||||
| if attr: | |||||
| class BadAttr(object): | |||||
| """Class that has a non-iterable __attrs_attrs__.""" | |||||
| __attrs_attrs__ = None | |||||
| @attr.s | |||||
| class SampleAttr(object): | |||||
| field1 = attr.ib() | |||||
| field2 = attr.ib() | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testAttrsFlattenAndPack(self): | |||||
| if attr is None: | |||||
| self.skipTest("attr module is unavailable.") | |||||
| field_values = [1, 2] | |||||
| sample_attr = NestTest.SampleAttr(*field_values) | |||||
| self.assertFalse(nest._is_attrs(field_values)) | |||||
| self.assertTrue(nest._is_attrs(sample_attr)) | |||||
| flat = nest.flatten(sample_attr) | |||||
| self.assertEqual(field_values, flat) | |||||
| restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) | |||||
| self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) | |||||
| self.assertEqual(restructured_from_flat, sample_attr) | |||||
| # Check that flatten fails if attributes are not iterable | |||||
| with self.assertRaisesRegexp(TypeError, "object is not iterable"): | |||||
| flat = nest.flatten(NestTest.BadAttr()) | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testFlattenAndPack(self): | |||||
| structure = ((3, 4), 5, (6, 7, (9, 10), 8)) | |||||
| flat = ["a", "b", "c", "d", "e", "f", "g", "h"] | |||||
| self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) | |||||
| self.assertEqual( | |||||
| nest.pack_sequence_as(structure, flat), (("a", "b"), "c", | |||||
| ("d", "e", ("f", "g"), "h"))) | |||||
| structure = (NestTest.PointXY(x=4, y=2), | |||||
| ((NestTest.PointXY(x=1, y=0),),)) | |||||
| flat = [4, 2, 1, 0] | |||||
| self.assertEqual(nest.flatten(structure), flat) | |||||
| restructured_from_flat = nest.pack_sequence_as(structure, flat) | |||||
| self.assertEqual(restructured_from_flat, structure) | |||||
| self.assertEqual(restructured_from_flat[0].x, 4) | |||||
| self.assertEqual(restructured_from_flat[0].y, 2) | |||||
| self.assertEqual(restructured_from_flat[1][0][0].x, 1) | |||||
| self.assertEqual(restructured_from_flat[1][0][0].y, 0) | |||||
| self.assertEqual([5], nest.flatten(5)) | |||||
| self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) | |||||
| self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) | |||||
| self.assertEqual( | |||||
| np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) | |||||
| with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): | |||||
| nest.pack_sequence_as("scalar", [4, 5]) | |||||
| with self.assertRaisesRegexp(TypeError, "flat_sequence"): | |||||
| nest.pack_sequence_as([4, 5], "bad_sequence") | |||||
| with self.assertRaises(ValueError): | |||||
| nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) | |||||
| @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||||
| {"mapping_type": _CustomMapping}) | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testFlattenDictOrder(self, mapping_type): | |||||
| """`flatten` orders dicts by key, including OrderedDicts.""" | |||||
| ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) | |||||
| plain = {"d": 3, "b": 1, "a": 0, "c": 2} | |||||
| ordered_flat = nest.flatten(ordered) | |||||
| plain_flat = nest.flatten(plain) | |||||
| self.assertEqual([0, 1, 2, 3], ordered_flat) | |||||
| self.assertEqual([0, 1, 2, 3], plain_flat) | |||||
| @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||||
| {"mapping_type": _CustomMapping}) | |||||
| def testPackDictOrder(self, mapping_type): | |||||
| """Packing orders dicts by key, including OrderedDicts.""" | |||||
| custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) | |||||
| plain = {"d": 0, "b": 0, "a": 0, "c": 0} | |||||
| seq = [0, 1, 2, 3] | |||||
| custom_reconstruction = nest.pack_sequence_as(custom, seq) | |||||
| plain_reconstruction = nest.pack_sequence_as(plain, seq) | |||||
| self.assertIsInstance(custom_reconstruction, mapping_type) | |||||
| self.assertIsInstance(plain_reconstruction, dict) | |||||
| self.assertEqual( | |||||
| mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), | |||||
| custom_reconstruction) | |||||
| self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) | |||||
| Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testFlattenAndPack_withDicts(self): | |||||
| # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. | |||||
| mess = [ | |||||
| "z", | |||||
| NestTest.Abc(3, 4), { | |||||
| "d": _CustomMapping({ | |||||
| 41: 4 | |||||
| }), | |||||
| "c": [ | |||||
| 1, | |||||
| collections.OrderedDict([ | |||||
| ("b", 3), | |||||
| ("a", 2), | |||||
| ]), | |||||
| ], | |||||
| "b": 5 | |||||
| }, 17 | |||||
| ] | |||||
| flattened = nest.flatten(mess) | |||||
| self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) | |||||
| structure_of_mess = [ | |||||
| 14, | |||||
| NestTest.Abc("a", True), | |||||
| { | |||||
| "d": _CustomMapping({ | |||||
| 41: 42 | |||||
| }), | |||||
| "c": [ | |||||
| 0, | |||||
| collections.OrderedDict([ | |||||
| ("b", 9), | |||||
| ("a", 8), | |||||
| ]), | |||||
| ], | |||||
| "b": 3 | |||||
| }, | |||||
| "hi everybody", | |||||
| ] | |||||
| unflattened = nest.pack_sequence_as(structure_of_mess, flattened) | |||||
| self.assertEqual(unflattened, mess) | |||||
| # Check also that the OrderedDict was created, with the correct key order. | |||||
| unflattened_ordered_dict = unflattened[2]["c"][1] | |||||
| self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) | |||||
| self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) | |||||
| unflattened_custom_mapping = unflattened[2]["d"] | |||||
| self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) | |||||
| self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) | |||||
| def testFlatten_numpyIsNotFlattened(self): | |||||
| structure = np.array([1, 2, 3]) | |||||
| flattened = nest.flatten(structure) | |||||
| self.assertEqual(len(flattened), 1) | |||||
| def testFlatten_stringIsNotFlattened(self): | |||||
| structure = "lots of letters" | |||||
| flattened = nest.flatten(structure) | |||||
| self.assertEqual(len(flattened), 1) | |||||
| unflattened = nest.pack_sequence_as("goodbye", flattened) | |||||
| self.assertEqual(structure, unflattened) | |||||
| def testPackSequenceAs_notIterableError(self): | |||||
| with self.assertRaisesRegexp(TypeError, | |||||
| "flat_sequence must be a sequence"): | |||||
| nest.pack_sequence_as("hi", "bye") | |||||
| def testPackSequenceAs_wrongLengthsError(self): | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| "Structure had 2 elements, but flat_sequence had 3 elements."): | |||||
| nest.pack_sequence_as(["hello", "world"], | |||||
| ["and", "goodbye", "again"]) | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testIsSequence(self): | |||||
| self.assertFalse(nest.is_sequence("1234")) | |||||
| self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) | |||||
| self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) | |||||
| self.assertTrue(nest.is_sequence([])) | |||||
| self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) | |||||
| self.assertFalse(nest.is_sequence(set([1, 2]))) | |||||
| ones = array_ops.ones([2, 3]) | |||||
| self.assertFalse(nest.is_sequence(ones)) | |||||
| self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) | |||||
| self.assertFalse(nest.is_sequence(np.ones((4, 5)))) | |||||
| @parameterized.parameters({"mapping_type": _CustomMapping}, | |||||
| {"mapping_type": dict}) | |||||
| def testFlattenDictItems(self, mapping_type): | |||||
| dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))}) | |||||
| flat = {4: "a", 5: "b", 6: "c", 8: "d"} | |||||
| self.assertEqual(nest.flatten_dict_items(dictionary), flat) | |||||
| with self.assertRaises(TypeError): | |||||
| nest.flatten_dict_items(4) | |||||
| bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))}) | |||||
| with self.assertRaisesRegexp(ValueError, "not unique"): | |||||
| nest.flatten_dict_items(bad_dictionary) | |||||
| another_bad_dictionary = mapping_type({ | |||||
| (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) | |||||
| }) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): | |||||
| nest.flatten_dict_items(another_bad_dictionary) | |||||
| # pylint does not correctly recognize these as class names and | |||||
| # suggests to use variable style under_score naming. | |||||
| # pylint: disable=invalid-name | |||||
| Named0ab = collections.namedtuple("named_0", ("a", "b")) | |||||
| Named1ab = collections.namedtuple("named_1", ("a", "b")) | |||||
| SameNameab = collections.namedtuple("same_name", ("a", "b")) | |||||
| SameNameab2 = collections.namedtuple("same_name", ("a", "b")) | |||||
| SameNamexy = collections.namedtuple("same_name", ("x", "y")) | |||||
| SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) | |||||
| SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) | |||||
| NotSameName = collections.namedtuple("not_same_name", ("a", "b")) | |||||
| # pylint: enable=invalid-name | |||||
| class SameNamedType1(SameNameab): | |||||
| pass | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testAssertSameStructure(self): | |||||
| structure1 = (((1, 2), 3), 4, (5, 6)) | |||||
| structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||||
| structure_different_num_elements = ("spam", "eggs") | |||||
| structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) | |||||
| nest.assert_same_structure(structure1, structure2) | |||||
| nest.assert_same_structure("abc", 1.0) | |||||
| nest.assert_same_structure("abc", np.array([0, 1])) | |||||
| nest.assert_same_structure("abc", constant_op.constant([0, 1])) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| ("The two structures don't have the same nested structure\\.\n\n" | |||||
| "First structure:.*?\n\n" | |||||
| "Second structure:.*\n\n" | |||||
| "More specifically: Substructure " | |||||
| r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' | |||||
| 'substructure "type=str str=spam" is not\n' | |||||
| "Entire first structure:\n" | |||||
| r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" | |||||
| "Entire second structure:\n" | |||||
| r"\(\., \.\)")): | |||||
| nest.assert_same_structure(structure1, structure_different_num_elements) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| ("The two structures don't have the same nested structure\\.\n\n" | |||||
| "First structure:.*?\n\n" | |||||
| "Second structure:.*\n\n" | |||||
| r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||||
| r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' | |||||
| "is not")): | |||||
| nest.assert_same_structure([0, 1], np.array([0, 1])) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| ("The two structures don't have the same nested structure\\.\n\n" | |||||
| "First structure:.*?\n\n" | |||||
| "Second structure:.*\n\n" | |||||
| r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||||
| 'is a sequence, while substructure "type=int str=0" ' | |||||
| "is not")): | |||||
| nest.assert_same_structure(0, [0, 1]) | |||||
| self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| ("don't have the same nested structure\\.\n\n" | |||||
| "First structure: .*?\n\nSecond structure: ")): | |||||
| nest.assert_same_structure(structure1, structure_different_nesting) | |||||
| self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), | |||||
| NestTest.Named0ab("a", "b")) | |||||
| nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||||
| NestTest.Named0ab("a", "b")) | |||||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| ("don't have the same nested structure\\.\n\n" | |||||
| "First structure: .*?\n\nSecond structure: ")): | |||||
| nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||||
| NestTest.Named0ab([3], 4)) | |||||
| with self.assertRaisesRegexp( | |||||
| ValueError, | |||||
| ("don't have the same nested structure\\.\n\n" | |||||
| "First structure: .*?\n\nSecond structure: ")): | |||||
| nest.assert_same_structure([[3], 4], [3, [4]]) | |||||
| structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||||
| with self.assertRaisesRegexp(TypeError, | |||||
| "don't have the same sequence type"): | |||||
| nest.assert_same_structure(structure1, structure1_list) | |||||
| nest.assert_same_structure(structure1, structure2, check_types=False) | |||||
| nest.assert_same_structure(structure1, structure1_list, check_types=False) | |||||
| with self.assertRaisesRegexp(ValueError, | |||||
| "don't have the same set of keys"): | |||||
| nest.assert_same_structure({"a": 1}, {"b": 1}) | |||||
| nest.assert_same_structure(NestTest.SameNameab(0, 1), | |||||
| NestTest.SameNameab2(2, 3)) | |||||
| # This assertion is expected to pass: two namedtuples with the same | |||||
| # name and field names are considered to be identical. | |||||
| nest.assert_same_structure( | |||||
| NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), | |||||
| NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) | |||||
| expected_message = "The two structures don't have the same.*" | |||||
| with self.assertRaisesRegexp(ValueError, expected_message): | |||||
| nest.assert_same_structure( | |||||
| NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), | |||||
| NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) | |||||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) | |||||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) | |||||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||||
| NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) | |||||
| EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name | |||||
| def testHeterogeneousComparison(self): | |||||
| nest.assert_same_structure({"a": 4}, _CustomMapping(a=3)) | |||||
| nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testMapStructure(self): | |||||
| structure1 = (((1, 2), 3), 4, (5, 6)) | |||||
| structure2 = (((7, 8), 9), 10, (11, 12)) | |||||
| structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) | |||||
| nest.assert_same_structure(structure1, structure1_plus1) | |||||
| self.assertAllEqual( | |||||
| [2, 3, 4, 5, 6, 7], | |||||
| nest.flatten(structure1_plus1)) | |||||
| structure1_plus_structure2 = nest.map_structure( | |||||
| lambda x, y: x + y, structure1, structure2) | |||||
| self.assertEqual( | |||||
| (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), | |||||
| structure1_plus_structure2) | |||||
| self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) | |||||
| self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) | |||||
| # Empty structures | |||||
| self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) | |||||
| self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) | |||||
| self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) | |||||
| self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, | |||||
| NestTest.EmptyNT())) | |||||
| # This is checking actual equality of types, empty list != empty tuple | |||||
| self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) | |||||
| with self.assertRaisesRegexp(TypeError, "callable"): | |||||
| nest.map_structure("bad", structure1_plus1) | |||||
| with self.assertRaisesRegexp(ValueError, "at least one structure"): | |||||
| nest.map_structure(lambda x: x) | |||||
| with self.assertRaisesRegexp(ValueError, "same number of elements"): | |||||
| nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) | |||||
| with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
| nest.map_structure(lambda x, y: None, 3, (3,)) | |||||
| with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||||
| nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) | |||||
| with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
| nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) | |||||
| structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||||
| with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||||
| nest.map_structure(lambda x, y: None, structure1, structure1_list) | |||||
| nest.map_structure(lambda x, y: None, structure1, structure1_list, | |||||
| check_types=False) | |||||
| with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||||
| nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), | |||||
| check_types=False) | |||||
| with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||||
| nest.map_structure(lambda x: None, structure1, foo="a") | |||||
| with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||||
| nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") | |||||
| ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name | |||||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||||
| def testMapStructureWithStrings(self): | |||||
| inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) | |||||
| inp_b = NestTest.ABTuple(a=2, b=(1, 3)) | |||||
| out = nest.map_structure(lambda string, repeats: string * repeats, | |||||
| inp_a, | |||||
| inp_b) | |||||
| self.assertEqual("foofoo", out.a) | |||||
| self.assertEqual("bar", out.b[0]) | |||||
| self.assertEqual("bazbazbaz", out.b[1]) | |||||
| nt = NestTest.ABTuple(a=("something", "something_else"), | |||||
| b="yet another thing") | |||||
| rev_nt = nest.map_structure(lambda x: x[::-1], nt) | |||||
| # Check the output is the correct structure, and all strings are reversed. | |||||
| nest.assert_same_structure(nt, rev_nt) | |||||
| self.assertEqual(nt.a[0][::-1], rev_nt.a[0]) | |||||
| self.assertEqual(nt.a[1][::-1], rev_nt.a[1]) | |||||
| self.assertEqual(nt.b[::-1], rev_nt.b) | |||||
| @test_util.run_deprecated_v1 | |||||
| def testMapStructureOverPlaceholders(self): | |||||
| inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||||
| array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||||
| inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||||
| array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||||
| output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) | |||||
| nest.assert_same_structure(output, inp_a) | |||||
| self.assertShapeEqual(np.zeros((3, 4)), output[0]) | |||||
| self.assertShapeEqual(np.zeros((3, 7)), output[1]) | |||||
| feed_dict = { | |||||
| inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), | |||||
| inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) | |||||
| } | |||||
| with self.cached_session() as sess: | |||||
| output_np = sess.run(output, feed_dict=feed_dict) | |||||
| self.assertAllClose(output_np[0], | |||||
| feed_dict[inp_a][0] + feed_dict[inp_b][0]) | |||||
| self.assertAllClose(output_np[1], | |||||
| feed_dict[inp_a][1] + feed_dict[inp_b][1]) | |||||
| def testAssertShallowStructure(self): | |||||
| inp_ab = ["a", "b"] | |||||
| inp_abc = ["a", "b", "c"] | |||||
| expected_message = ( | |||||
| "The two structures don't have the same sequence length. Input " | |||||
| "structure has length 2, while shallow structure has length 3.") | |||||
| with self.assertRaisesRegexp(ValueError, expected_message): | |||||
| nest.assert_shallow_structure(inp_abc, inp_ab) | |||||
| inp_ab1 = [(1, 1), (2, 2)] | |||||
| inp_ab2 = [[1, 1], [2, 2]] | |||||
| expected_message = ( | |||||
| "The two structures don't have the same sequence type. Input structure " | |||||
| "has type <(type|class) 'tuple'>, while shallow structure has type " | |||||
| "<(type|class) 'list'>.") | |||||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||||
| nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) | |||||
| inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} | |||||
| inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} | |||||
| expected_message = ( | |||||
| r"The two structures don't have the same keys. Input " | |||||
| r"structure has keys \['c'\], while shallow structure has " | |||||
| r"keys \['d'\].") | |||||
| with self.assertRaisesRegexp(ValueError, expected_message): | |||||
| nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||||
| inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) | |||||
| inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) | |||||
| nest.assert_shallow_structure(inp_ab, inp_ba) | |||||
| # This assertion is expected to pass: two namedtuples with the same | |||||
| # name and field names are considered to be identical. | |||||
| inp_shallow = NestTest.SameNameab(1, 2) | |||||
| inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) | |||||
| nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) | |||||
| nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) | |||||
| def testFlattenUpTo(self): | |||||
| # Shallow tree ends at scalar. | |||||
| input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||||
| shallow_tree = [[True, True], [False, True]] | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) | |||||
| self.assertEqual(flattened_shallow_tree, [True, True, False, True]) | |||||
| # Shallow tree ends at string. | |||||
| input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] | |||||
| shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] | |||||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| input_tree) | |||||
| input_tree_flattened = nest.flatten(input_tree) | |||||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) | |||||
| self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) | |||||
| # Make sure dicts are correctly flattened, yielding values, not keys. | |||||
| input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} | |||||
| shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} | |||||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| input_tree) | |||||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| [1, {"c": 2}, 3, (4, 5)]) | |||||
| # Namedtuples. | |||||
| ab_tuple = NestTest.ABTuple | |||||
| input_tree = ab_tuple(a=[0, 1], b=2) | |||||
| shallow_tree = ab_tuple(a=0, b=1) | |||||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| input_tree) | |||||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| [[0, 1], 2]) | |||||
| # Nested dicts, OrderedDicts and namedtuples. | |||||
| input_tree = collections.OrderedDict( | |||||
| [("a", ab_tuple(a=[0, {"b": 1}], b=2)), | |||||
| ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) | |||||
| shallow_tree = input_tree | |||||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| input_tree) | |||||
| self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) | |||||
| shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) | |||||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| input_tree) | |||||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| [ab_tuple(a=[0, {"b": 1}], b=2), | |||||
| 3, | |||||
| collections.OrderedDict([("f", 4)])]) | |||||
| shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) | |||||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||||
| input_tree) | |||||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||||
| [ab_tuple(a=[0, {"b": 1}], b=2), | |||||
| {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) | |||||
| ## Shallow non-list edge-case. | |||||
| # Using iterable elements. | |||||
| input_tree = ["input_tree"] | |||||
| shallow_tree = "shallow_tree" | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| input_tree = ["input_tree_0", "input_tree_1"] | |||||
| shallow_tree = "shallow_tree" | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| # Using non-iterable elements. | |||||
| input_tree = [0] | |||||
| shallow_tree = 9 | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| input_tree = [0, 1] | |||||
| shallow_tree = 9 | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| ## Both non-list edge-case. | |||||
| # Using iterable elements. | |||||
| input_tree = "input_tree" | |||||
| shallow_tree = "shallow_tree" | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| # Using non-iterable elements. | |||||
| input_tree = 0 | |||||
| shallow_tree = 0 | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||||
| ## Input non-list edge-case. | |||||
| # Using iterable elements. | |||||
| input_tree = "input_tree" | |||||
| shallow_tree = ["shallow_tree"] | |||||
| expected_message = ("If shallow structure is a sequence, input must also " | |||||
| "be a sequence. Input has type: <(type|class) 'str'>.") | |||||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| input_tree = "input_tree" | |||||
| shallow_tree = ["shallow_tree_9", "shallow_tree_8"] | |||||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| # Using non-iterable elements. | |||||
| input_tree = 0 | |||||
| shallow_tree = [9] | |||||
| expected_message = ("If shallow structure is a sequence, input must also " | |||||
| "be a sequence. Input has type: <(type|class) 'int'>.") | |||||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| input_tree = 0 | |||||
| shallow_tree = [9, 8] | |||||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||||
| def testMapStructureUpTo(self): | |||||
| # Named tuples. | |||||
| ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||||
| op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||||
| inp_val = ab_tuple(a=2, b=3) | |||||
| inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) | |||||
| out = nest.map_structure_up_to( | |||||
| inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) | |||||
| self.assertEqual(out.a, 6) | |||||
| self.assertEqual(out.b, 15) | |||||
| # Lists. | |||||
| data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||||
| name_list = ["evens", ["odds", "primes"]] | |||||
| out = nest.map_structure_up_to( | |||||
| name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), | |||||
| name_list, data_list) | |||||
| self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) | |||||
| # Dicts. | |||||
| inp_val = dict(a=2, b=3) | |||||
| inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) | |||||
| out = nest.map_structure_up_to( | |||||
| inp_val, | |||||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| self.assertEqual(out["a"], 6) | |||||
| self.assertEqual(out["b"], 15) | |||||
| # Non-equal dicts. | |||||
| inp_val = dict(a=2, b=3) | |||||
| inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) | |||||
| with self.assertRaisesRegexp(ValueError, "same keys"): | |||||
| nest.map_structure_up_to( | |||||
| inp_val, | |||||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| # Dict+custom mapping. | |||||
| inp_val = dict(a=2, b=3) | |||||
| inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) | |||||
| out = nest.map_structure_up_to( | |||||
| inp_val, | |||||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| self.assertEqual(out["a"], 6) | |||||
| self.assertEqual(out["b"], 15) | |||||
| # Non-equal dict/mapping. | |||||
| inp_val = dict(a=2, b=3) | |||||
| inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) | |||||
| with self.assertRaisesRegexp(ValueError, "same keys"): | |||||
| nest.map_structure_up_to( | |||||
| inp_val, | |||||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||||
| def testGetTraverseShallowStructure(self): | |||||
| scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] | |||||
| scalar_traverse_r = nest.get_traverse_shallow_structure( | |||||
| lambda s: not isinstance(s, tuple), | |||||
| scalar_traverse_input) | |||||
| self.assertEqual(scalar_traverse_r, | |||||
| [True, True, False, [True, True], {"a": False}, []]) | |||||
| nest.assert_shallow_structure(scalar_traverse_r, | |||||
| scalar_traverse_input) | |||||
| structure_traverse_input = [(1, [2]), ([1], 2)] | |||||
| structure_traverse_r = nest.get_traverse_shallow_structure( | |||||
| lambda s: (True, False) if isinstance(s, tuple) else True, | |||||
| structure_traverse_input) | |||||
| self.assertEqual(structure_traverse_r, | |||||
| [(True, False), ([True], False)]) | |||||
| nest.assert_shallow_structure(structure_traverse_r, | |||||
| structure_traverse_input) | |||||
| with self.assertRaisesRegexp(TypeError, "returned structure"): | |||||
| nest.get_traverse_shallow_structure(lambda _: [True], 0) | |||||
| with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): | |||||
| nest.get_traverse_shallow_structure(lambda _: 1, [1]) | |||||
| with self.assertRaisesRegexp( | |||||
| TypeError, "didn't return a depth=1 structure of bools"): | |||||
| nest.get_traverse_shallow_structure(lambda _: [1], [1]) | |||||
| def testYieldFlatStringPaths(self): | |||||
| for inputs_expected in ({"inputs": [], "expected": []}, | |||||
| {"inputs": 3, "expected": [()]}, | |||||
| {"inputs": [3], "expected": [(0,)]}, | |||||
| {"inputs": {"a": 3}, "expected": [("a",)]}, | |||||
| {"inputs": {"a": {"b": 4}}, | |||||
| "expected": [("a", "b")]}, | |||||
| {"inputs": [{"a": 2}], "expected": [(0, "a")]}, | |||||
| {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, | |||||
| {"inputs": [{"a": [(23, 42)]}], | |||||
| "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, | |||||
| {"inputs": [{"a": ([23], 42)}], | |||||
| "expected": [(0, "a", 0, 0), (0, "a", 1)]}, | |||||
| {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, | |||||
| "expected": [("a", "a"), ("c", 0, 0, 0)]}, | |||||
| {"inputs": {"0": [{"1": 23}]}, | |||||
| "expected": [("0", 0, "1")]}): | |||||
| inputs = inputs_expected["inputs"] | |||||
| expected = inputs_expected["expected"] | |||||
| self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) | |||||
| def testFlattenWithStringPaths(self): | |||||
| for inputs_expected in ( | |||||
| {"inputs": [], "expected": []}, | |||||
| {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, | |||||
| {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): | |||||
| inputs = inputs_expected["inputs"] | |||||
| expected = inputs_expected["expected"] | |||||
| self.assertEqual( | |||||
| nest.flatten_with_joined_string_paths(inputs, separator="/"), | |||||
| expected) | |||||
| # Need a separate test for namedtuple as we can't declare tuple definitions | |||||
| # in the @parameterized arguments. | |||||
| def testFlattenNamedTuple(self): | |||||
| # pylint: disable=invalid-name | |||||
| Foo = collections.namedtuple("Foo", ["a", "b"]) | |||||
| Bar = collections.namedtuple("Bar", ["c", "d"]) | |||||
| # pylint: enable=invalid-name | |||||
| test_cases = [ | |||||
| (Foo(a=3, b=Bar(c=23, d=42)), | |||||
| [("a", 3), ("b/c", 23), ("b/d", 42)]), | |||||
| (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")), | |||||
| [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), | |||||
| (Bar(c=42, d=43), | |||||
| [("c", 42), ("d", 43)]), | |||||
| (Bar(c=[42], d=43), | |||||
| [("c/0", 42), ("d", 43)]), | |||||
| ] | |||||
| for inputs, expected in test_cases: | |||||
| self.assertEqual( | |||||
| list(nest.flatten_with_joined_string_paths(inputs)), expected) | |||||
| @parameterized.named_parameters( | |||||
| ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), | |||||
| ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, | |||||
| {"a": ("a", 4), "b": ("b", 6)}), | |||||
| ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), | |||||
| ("nested", | |||||
| {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, | |||||
| {"a": [("a/0", 10), ("a/1", 12)], | |||||
| "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) | |||||
| def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): | |||||
| def format_sum(path, *values): | |||||
| return (path, sum(values)) | |||||
| result = nest.map_structure_with_paths(format_sum, s1, s2, | |||||
| check_types=check_types) | |||||
| self.assertEqual(expected, result) | |||||
| @parameterized.named_parameters( | |||||
| ("tuples", (1, 2), (3, 4, 5), ValueError), | |||||
| ("dicts", {"a": 1}, {"b": 2}, ValueError), | |||||
| ("mixed", (1, 2), [3, 4], TypeError), | |||||
| ("nested", | |||||
| {"a": [2, 3], "b": [1, 3]}, | |||||
| {"b": [5, 6, 7], "a": [8, 9]}, | |||||
| ValueError | |||||
| )) | |||||
| def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): | |||||
| with self.assertRaises(error_type): | |||||
| nest.map_structure_with_paths(lambda path, *s: 0, s1, s2) | |||||
| class NestBenchmark(test.Benchmark): | |||||
| def run_and_report(self, s1, s2, name): | |||||
| burn_iter, test_iter = 100, 30000 | |||||
| for _ in xrange(burn_iter): | |||||
| nest.assert_same_structure(s1, s2) | |||||
| t0 = time.time() | |||||
| for _ in xrange(test_iter): | |||||
| nest.assert_same_structure(s1, s2) | |||||
| t1 = time.time() | |||||
| self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, | |||||
| name=name) | |||||
| def benchmark_assert_structure(self): | |||||
| s1 = (((1, 2), 3), 4, (5, 6)) | |||||
| s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||||
| self.run_and_report(s1, s2, "assert_same_structure_6_elem") | |||||
| s1 = (((1, 2), 3), 4, (5, 6)) * 10 | |||||
| s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 | |||||
| self.run_and_report(s1, s2, "assert_same_structure_60_elem") | |||||
| if __name__ == "__main__": | |||||
| test.main() | |||||