more cond tests and parts of nest.pytags/v0.9
| @@ -27,10 +27,17 @@ namespace Tensorflow | |||
| return op.type == "Switch" || op.type == "RefSwitch"; | |||
| } | |||
| /// <summary> | |||
| /// Return the control flow context for the output of an op. | |||
| /// </summary> | |||
| public static IControlFlowContext GetOutputContext(Operation op) | |||
| { | |||
| var ctxt = op._get_control_flow_context(); | |||
| // Exit nodes usually have a control flow context, except in the case where the | |||
| // exit node was imported via import_graph_def (in which case no nodes have | |||
| // control flow contexts). | |||
| if (ctxt != null && IsLoopExit(op)) | |||
| ctxt = ctxt.outer_context; | |||
| return ctxt; | |||
| } | |||
| } | |||
| @@ -36,6 +36,7 @@ namespace Tensorflow | |||
| return instance; | |||
| } | |||
| [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception | |||
| public static void with(IPython py, Action<IPython> action) | |||
| { | |||
| try | |||
| @@ -55,6 +56,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception | |||
| public static void with<T>(T py, Action<T> action) where T : IPython | |||
| { | |||
| try | |||
| @@ -74,6 +76,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| [DebuggerNonUserCode()] // with "Just My Code" enabled this lets the debugger break at the origin of the exception | |||
| public static TOut with<TIn, TOut>(TIn py, Func<TIn, TOut> action) where TIn : IPython | |||
| { | |||
| try | |||
| @@ -0,0 +1,871 @@ | |||
| using System; | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using NumSharp; | |||
| namespace Tensorflow.Util | |||
| { | |||
| //Functions for working with arbitrarily nested sequences of elements. | |||
| //This module can perform operations on nested structures. A nested structure is a | |||
| //Python sequence, tuple (including `namedtuple`), or dict that can contain | |||
| //further sequences, tuples, and dicts. | |||
| //The utilities here assume (and do not check) that the nested structures form a | |||
| //'tree', i.e., no references in the structure of the input of these functions | |||
| //should be recursive. | |||
| //Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0), | |||
| // (np.array([3, 4]), tf.constant([3, 4])))` | |||
| // | |||
| public static class nest | |||
| { | |||
| //def _get_attrs_values(obj): | |||
| // """Returns the list of values from an attrs instance.""" | |||
| // attrs = getattr(obj.__class__, "__attrs_attrs__") | |||
| // return [getattr(obj, a.name) for a in attrs] | |||
| /// <summary> | |||
| /// Returns a sorted list of the dict keys, with error if keys not sortable. | |||
| /// </summary> | |||
| private static IEnumerable<string> _sorted(IDictionary dict_) | |||
| { | |||
| return dict_.Keys.OfType<string>().OrderBy(x => x); | |||
| } | |||
| //def _is_namedtuple(instance, strict=False): | |||
| // """Returns True iff `instance` is a `namedtuple`. | |||
| // Args: | |||
| // instance: An instance of a Python object. | |||
| // strict: If True, `instance` is considered to be a `namedtuple` only if | |||
| // it is a "plain" namedtuple. For instance, a class inheriting | |||
| // from a `namedtuple` will be considered to be a `namedtuple` | |||
| // iff `strict=False`. | |||
| // Returns: | |||
| // True if `instance` is a `namedtuple`. | |||
| // """ | |||
| // return _pywrap_tensorflow.IsNamedtuple(instance, strict) | |||
| //# See the swig file (util.i) for documentation. | |||
| //_is_mapping = _pywrap_tensorflow.IsMapping | |||
| //_is_attrs = _pywrap_tensorflow.IsAttrs | |||
| /// <summary> | |||
| /// Converts the sequence `args` to the same type as `instance`. | |||
| /// </summary> | |||
| /// <param name="instance">an instance of `tuple`, `list`, `namedtuple`, `dict`, or | |||
| /// `collections.OrderedDict`.</param> | |||
| /// <param name="args">elements to be converted to the `instance` type.</param> | |||
| /// <returns>`args` with the type of `instance`.</returns> | |||
| private static object _sequence_like(object instance, IEnumerable<object> args) | |||
| { | |||
| if (is_mapping(instance)) | |||
| { | |||
| //# Pack dictionaries in a deterministic order by sorting the keys. | |||
| //# Notice this means that we ignore the original order of `OrderedDict` | |||
| //# instances. This is intentional, to avoid potential bugs caused by mixing | |||
| //# ordered and plain dicts (e.g., flattening a dict but using a | |||
| //# corresponding `OrderedDict` to pack it back). | |||
| // result = dict(zip(_sorted(instance), args)) | |||
| // return type(instance)((key, result[key]) for key in _six.iterkeys(instance)) | |||
| } | |||
| //else if( _is_namedtuple(instance) || _is_attrs(instance)) | |||
| // return type(instance)(*args) | |||
| else | |||
| { | |||
| // Not a namedtuple | |||
| switch (instance) | |||
| { | |||
| case object[] array: | |||
| var result_array = new object[args.Count()]; | |||
| int i = 0; | |||
| foreach (var x in args) | |||
| { | |||
| result_array[i] = x; | |||
| i++; | |||
| } | |||
| return result_array; | |||
| case List<object> list: | |||
| return new List<object>(args); | |||
| default: | |||
| throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | |||
| } | |||
| } | |||
| throw new TypeError("Type of sequence not supported (yet): " + instance.GetType()); | |||
| } | |||
| /// <summary> | |||
| /// Yields the next value from the given iterable. | |||
| /// </summary> | |||
| private static IEnumerable<object> _yield_value(object iterable) | |||
| { | |||
| if (is_mapping(iterable)) | |||
| { | |||
| var dict = iterable as IDictionary; | |||
| //# Iterate through dictionaries in a deterministic order by sorting the | |||
| //# keys. Notice this means that we ignore the original order of `OrderedDict` | |||
| //# instances. This is intentional, to avoid potential bugs caused by mixing | |||
| //# ordered and plain dicts (e.g., flattening a dict but using a | |||
| //# corresponding `OrderedDict` to pack it back). | |||
| foreach (var key in _sorted(dict)) | |||
| yield return dict[key]; | |||
| } | |||
| //else if (_is_attrs(iterable)) | |||
| //{ | |||
| // // for value in _get_attrs_values(iterable): | |||
| // // yield value | |||
| //} | |||
| else if (iterable is IEnumerable) | |||
| { | |||
| var enumerable = iterable as IEnumerable; | |||
| foreach (var value in enumerable) | |||
| yield return value; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError("Unexpected iterable type: " + iterable.GetType()); | |||
| //var jobj = JObject.FromObject(iterable); | |||
| //foreach (var key in _sorted()) | |||
| // yield return jobj[key]; | |||
| } | |||
| } | |||
| //# See the swig file (util.i) for documentation. | |||
| public static bool is_sequence(object arg) => arg is IEnumerable && !(arg is string); | |||
| public static bool is_mapping(object arg) => arg is IDictionary; | |||
| //# See the swig file (util.i) for documentation. | |||
| //flatten = _pywrap_tensorflow.Flatten | |||
| public static List<object> flatten(object structure) | |||
| { | |||
| var list = new List<object>(); | |||
| _flatten_recursive(structure, list); | |||
| return list; | |||
| } | |||
| private static void _flatten_recursive(object obj, List<object> list) | |||
| { | |||
| if (obj is string) | |||
| { | |||
| list.Add(obj); | |||
| return; | |||
| } | |||
| if (obj is IDictionary) | |||
| { | |||
| var dict = obj as IDictionary; | |||
| foreach (var key in _sorted(dict)) | |||
| _flatten_recursive(dict[key], list); | |||
| return; | |||
| } | |||
| if (obj is NDArray) | |||
| { | |||
| list.Add(obj); | |||
| return; | |||
| } | |||
| if (obj is IEnumerable) | |||
| { | |||
| var structure = obj as IEnumerable; | |||
| foreach (var child in structure) | |||
| _flatten_recursive(child, list); | |||
| return; | |||
| } | |||
| list.Add(obj); | |||
| } | |||
| //# See the swig file (util.i) for documentation. | |||
| //_same_namedtuples = _pywrap_tensorflow.SameNamedtuples | |||
| //class _DotString(object): | |||
| // def __str__(self): | |||
| // return "." | |||
| // def __repr__(self): | |||
| // return "." | |||
| //_DOT = _DotString() | |||
| //def assert_same_structure(nest1, nest2, check_types=True): | |||
| // """Asserts that two structures are nested in the same way. | |||
| // Note that namedtuples with identical name and fields are always considered | |||
| // to have the same shallow structure (even with `check_types=True`). | |||
| // For intance, this code will print `True`: | |||
| // ```python | |||
| // def nt(a, b): | |||
| // return collections.namedtuple('foo', 'a b')(a, b) | |||
| // print(assert_same_structure(nt(0, 1), nt(2, 3))) | |||
| // ``` | |||
| // Args: | |||
| // nest1: an arbitrarily nested structure. | |||
| // nest2: an arbitrarily nested structure. | |||
| // check_types: if `True` (default) types of sequences are checked as well, | |||
| // including the keys of dictionaries. If set to `False`, for example a | |||
| // list and a tuple of objects will look the same if they have the same | |||
| // size. Note that namedtuples with identical name and fields are always | |||
| // considered to have the same shallow structure. Two types will also be | |||
| // considered the same if they are both list subtypes (which allows "list" | |||
| // and "_ListWrapper" from checkpointable dependency tracking to compare | |||
| // equal). | |||
| // Raises: | |||
| // ValueError: If the two structures do not have the same number of elements or | |||
| // if the two structures are not nested in the same way. | |||
| // TypeError: If the two structures differ in the type of sequence in any of | |||
| // their substructures. Only possible if `check_types` is `True`. | |||
| // """ | |||
| // try: | |||
| // _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types) | |||
| // except (ValueError, TypeError) as e: | |||
| // str1 = str(map_structure(lambda _: _DOT, nest1)) | |||
| // str2 = str(map_structure(lambda _: _DOT, nest2)) | |||
| // raise type(e)("%s\n" | |||
| // "Entire first structure:\n%s\n" | |||
| // "Entire second structure:\n%s" | |||
| // % (str(e), str1, str2)) | |||
| //def flatten_dict_items(dictionary): | |||
| // """Returns a dictionary with flattened keys and values. | |||
| // This function flattens the keys and values of a dictionary, which can be | |||
| // arbitrarily nested structures, and returns the flattened version of such | |||
| // structures: | |||
| // ```python | |||
| // example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))} | |||
| // result = {4: "a", 5: "b", 6: "c", 8: "d"} | |||
| // flatten_dict_items(example_dictionary) == result | |||
| // ``` | |||
| // The input dictionary must satisfy two properties: | |||
| // 1. Its keys and values should have the same exact nested structure. | |||
| // 2. The set of all flattened keys of the dictionary must not contain repeated | |||
| // keys. | |||
| // Args: | |||
| // dictionary: the dictionary to zip | |||
| // Returns: | |||
| // The zipped dictionary. | |||
| // Raises: | |||
| // TypeError: If the input is not a dictionary. | |||
| // ValueError: If any key and value have not the same structure, or if keys are | |||
| // not unique. | |||
| // """ | |||
| // if not isinstance(dictionary, (dict, _collections.Mapping)): | |||
| // raise TypeError("input must be a dictionary") | |||
| // flat_dictionary = {} | |||
| // for i, v in _six.iteritems(dictionary): | |||
| // if not is_sequence(i): | |||
| // if i in flat_dictionary: | |||
| // raise ValueError( | |||
| // "Could not flatten dictionary: key %s is not unique." % i) | |||
| // flat_dictionary[i] = v | |||
| // else: | |||
| // flat_i = flatten(i) | |||
| // flat_v = flatten(v) | |||
| // if len(flat_i) != len(flat_v): | |||
| // raise ValueError( | |||
| // "Could not flatten dictionary. Key had %d elements, but value had " | |||
| // "%d elements. Key: %s, value: %s." | |||
| // % (len(flat_i), len(flat_v), flat_i, flat_v)) | |||
| // for new_i, new_v in zip(flat_i, flat_v): | |||
| // if new_i in flat_dictionary: | |||
| // raise ValueError( | |||
| // "Could not flatten dictionary: key %s is not unique." | |||
| // % (new_i)) | |||
| // flat_dictionary[new_i] = new_v | |||
| // return flat_dictionary | |||
| /// <summary> | |||
| /// Helper function for pack_sequence_as. | |||
| /// </summary> | |||
| /// <param name="structure">Substructure (list / tuple / dict) to mimic.</param> | |||
| /// <param name="flat">Flattened values to output substructure for.</param> | |||
| /// <param name="index">Index at which to start reading from flat.</param> | |||
| /// <returns> | |||
| /// The tuple(new_index, child), where: | |||
| /// * new_index - the updated index into `flat` having processed `structure`. | |||
| /// * packed - the subset of `flat` corresponding to `structure`, | |||
| /// having started at `index`, and packed into the same nested | |||
| /// format.</returns> | |||
| private static (int new_index, List<object> child) _packed_nest_with_indices(object structure, List<object> flat, | |||
| int index) | |||
| { | |||
| var packed = new List<object>(); | |||
| foreach (var s in _yield_value(structure)) | |||
| { | |||
| if (is_sequence(s)) | |||
| { | |||
| var (new_index, child) = _packed_nest_with_indices(s, flat, index); | |||
| packed.Add(_sequence_like(s, child)); | |||
| index = new_index; | |||
| } | |||
| else | |||
| { | |||
| packed.Add(flat[index]); | |||
| index += 1; | |||
| } | |||
| } | |||
| return (index, packed); | |||
| } | |||
| private static int len(IEnumerable<object> x) => x.Count(); | |||
| /// <summary> | |||
| /// Returns a given flattened sequence packed into a given structure. | |||
| /// If `structure` is a scalar, `flat_sequence` must be a single-element list; | |||
| /// in this case the return value is `flat_sequence[0]`. | |||
| /// | |||
| /// If `structure` is or contains a dict instance, the keys will be sorted to | |||
| /// pack the flat sequence in deterministic order. This is true also for | |||
| /// `OrderedDict` instances: their sequence order is ignored, the sorting order of | |||
| /// keys is used instead. The same convention is followed in `flatten`. | |||
| /// This correctly repacks dicts and `OrderedDict`s after they have been | |||
| /// flattened, and also allows flattening an `OrderedDict` and then repacking it | |||
| /// back using a corresponding plain dict, or vice-versa. | |||
| /// Dictionaries with non-sortable keys cannot be flattened. | |||
| /// </summary> | |||
| /// <param name="structure"> | |||
| /// Nested structure, whose structure is given by nested lists, | |||
| /// tuples, and dicts. Note: numpy arrays and strings are considered | |||
| /// scalars. | |||
| /// </param> | |||
| /// <param name="flat_sequence"> flat sequence to pack.</param> | |||
| /// <returns> `flat_sequence` converted to have the same recursive structure as | |||
| /// `structure`. | |||
| /// </returns> | |||
| public static object pack_sequence_as(object structure, List<object> flat_sequence) | |||
| { | |||
| if (flat_sequence == null) | |||
| throw new ArgumentException("flat_sequence must not be null"); | |||
| // if not is_sequence(flat_sequence): | |||
| // raise TypeError("flat_sequence must be a sequence") | |||
| if (!is_sequence(structure)) | |||
| { | |||
| if (len(flat_sequence) != 1) | |||
| throw new ValueError($"Structure is a scalar but len(flat_sequence) == {len(flat_sequence)} > 1"); | |||
| return flat_sequence.FirstOrDefault(); | |||
| } | |||
| int final_index = 0; | |||
| List<object> packed = null; | |||
| try | |||
| { | |||
| (final_index, packed) = _packed_nest_with_indices(structure, flat_sequence, 0); | |||
| if (final_index < len(flat_sequence)) | |||
| throw new IndexOutOfRangeException($"Final index: { final_index} was smaller than len(flat_sequence): { len(flat_sequence) }"); | |||
| } | |||
| catch (IndexOutOfRangeException) | |||
| { | |||
| var flat_structure = flatten(structure); | |||
| if (len(flat_structure) != len(flat_sequence)) | |||
| { | |||
| throw new ValueError("Could not pack sequence. Structure had %d elements, but " + | |||
| $"flat_sequence had {len(flat_structure)} elements. flat_sequence had: {len(flat_sequence)}"); | |||
| } | |||
| return _sequence_like(structure, packed); | |||
| } | |||
| return packed; | |||
| } | |||
| /// <summary> | |||
| /// Applies `func` to each entry in `structure` and returns a new structure. | |||
| /// | |||
| /// Applies `func(x[0], x[1], ...)` where x[i] is an entry in | |||
| /// `structure[i]`. All structures in `structure` must have the same arity, | |||
| /// and the return value will contain the results in the same structure. | |||
| /// </summary> | |||
| /// <typeparam name="T"></typeparam> | |||
| /// <typeparam name="U"></typeparam> | |||
| /// <param name="func"> A callable that accepts as many arguments as there are structures.</param> | |||
| /// <param name="structure">scalar, or tuple or list of constructed scalars and/or other | |||
| /// tuples/lists, or scalars. Note: numpy arrays are considered as scalars.</param> | |||
| /// <param name="check_types">If set to | |||
| /// `True` (default) the types of iterables within the structures have to be | |||
| /// same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError` | |||
| /// exception). To allow this set this argument to `False`. | |||
| /// Note that namedtuples with identical name and fields are always | |||
| /// considered to have the same shallow structure.</param> | |||
| /// <returns> | |||
| /// A new structure with the same arity as `structure`, whose values correspond | |||
| /// to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding | |||
| /// location in `structure[i]`. If there are different sequence types and | |||
| /// `check_types` is `False` the sequence types of the first structure will be | |||
| /// used. | |||
| /// </returns> | |||
| public static IEnumerable<U> map_structure<T, U>(Func<T, U> func, IEnumerable<T> structure, bool check_types = false) | |||
| { | |||
| // for other in structure[1:]: | |||
| // assert_same_structure(structure[0], other, check_types=check_types) | |||
| // flat_structure = [flatten(s) for s in structure] | |||
| // entries = zip(*flat_structure) | |||
| // return pack_sequence_as( | |||
| // structure[0], [func(*x) for x in entries]) | |||
| return null; | |||
| } | |||
| //def map_structure_with_paths(func, *structure, **kwargs): | |||
| // """Applies `func` to each entry in `structure` and returns a new structure. | |||
| // Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in | |||
| // `structure[i]` and `path` is the common path to x[i] in the structures. All | |||
| // structures in `structure` must have the same arity, and the return value will | |||
| // contain the results in the same structure. Special kwarg `check_types` | |||
| // determines whether the types of iterables within the structure must be the | |||
| // same-- see **kwargs definition below. | |||
| // Args: | |||
| // func: A callable with the signature func(path, *values, **kwargs) that is | |||
| // evaluated on the leaves of the structure. | |||
| // *structure: A variable number of compatible structures to process. | |||
| // **kwargs: Optional kwargs to be passed through to func. Special kwarg | |||
| // `check_types` is not passed to func, but instead determines whether the | |||
| // types of iterables within the structures have to be same (e.g., | |||
| // `map_structure(func, [1], (1,))` raises a `TypeError` exception). By | |||
| // default, the types must match. To allow iteration over structures of | |||
| // different types (but common arity), set this kwarg to `False`. | |||
| // Returns: | |||
| // A structure of the same form as the input structures whose leaves are the | |||
| // result of evaluating func on corresponding leaves of the input structures. | |||
| // Raises: | |||
| // TypeError: If `func` is not callable or if the structures do not match | |||
| // each other by depth tree. | |||
| // TypeError: If `check_types` is not `False` and the two structures differ in | |||
| // the type of sequence in any of their substructures. | |||
| // ValueError: If no structures are provided. | |||
| // """ | |||
| // if not callable(func): | |||
| // raise TypeError("func must be callable, got: %s" % func) | |||
| // if not structure: | |||
| // raise ValueError("Must provide at least one structure") | |||
| // check_types = kwargs.pop("check_types", True) | |||
| // for other in structure[1:]: | |||
| // assert_same_structure(structure[0], other, check_types=check_types) | |||
| //# First set paths_and_values to: | |||
| //# [[(p11, v11), ... (p1n, v1n)], ... [(pm1, vm1), ... (pmn, vmn)]] | |||
| // paths_and_values = [flatten_with_joined_string_paths(s) for s in structure] | |||
| //# Now zip(*paths_and_values) would be: | |||
| //# [((p11, v11), ... (pm1, vm1)), ... ((p1n, v1n), ... (pmn, vmn))] | |||
| //# so grouped_by_path is set to: | |||
| //# [[(p11, ... pm1), (v11, ... vm1)], ... [(p1n, ... pmn), (v1n, ... vmn)]] | |||
| //# Note that p1i, ... pmi must all be equal since the structures are the same. | |||
| // grouped_by_path = [zip(*p_v) for p_v in zip(*paths_and_values)] | |||
| // return pack_sequence_as(structure[0], [ | |||
| // func(paths[0], *values, **kwargs) for paths, values in grouped_by_path]) | |||
| //def _yield_flat_up_to(shallow_tree, input_tree): | |||
| // """Yields elements `input_tree` partially flattened up to `shallow_tree`.""" | |||
| // if is_sequence(shallow_tree): | |||
| // for shallow_branch, input_branch in zip(_yield_value(shallow_tree), | |||
| // _yield_value(input_tree)): | |||
| // for input_leaf in _yield_flat_up_to(shallow_branch, input_branch): | |||
| // yield input_leaf | |||
| // else: | |||
| // yield input_tree | |||
| //def assert_shallow_structure(shallow_tree, input_tree, check_types=True): | |||
| // """Asserts that `shallow_tree` is a shallow structure of `input_tree`. | |||
| // That is, this function tests if the `input_tree` structure can be created from | |||
| // the `shallow_tree` structure by replacing its leaf nodes with deeper | |||
| // tree structures. | |||
| // Examples: | |||
| // The following code will raise an exception: | |||
| // ```python | |||
| // shallow_tree = ["a", "b"] | |||
| // input_tree = ["c", ["d", "e"], "f"] | |||
| // assert_shallow_structure(shallow_tree, input_tree) | |||
| // ``` | |||
| // The following code will not raise an exception: | |||
| // ```python | |||
| // shallow_tree = ["a", "b"] | |||
| // input_tree = ["c", ["d", "e"]] | |||
| // assert_shallow_structure(shallow_tree, input_tree) | |||
| // ``` | |||
| // Args: | |||
| // shallow_tree: an arbitrarily nested structure. | |||
| // input_tree: an arbitrarily nested structure. | |||
| // check_types: if `True` (default) the sequence types of `shallow_tree` and | |||
| // `input_tree` have to be the same. Note that even with check_types==True, | |||
| // this function will consider two different namedtuple classes with the same | |||
| // name and _fields attribute to be the same class. | |||
| // Raises: | |||
| // TypeError: If `shallow_tree` is a sequence but `input_tree` is not. | |||
| // TypeError: If the sequence types of `shallow_tree` are different from | |||
| // `input_tree`. Only raised if `check_types` is `True`. | |||
| // ValueError: If the sequence lengths of `shallow_tree` are different from | |||
| // `input_tree`. | |||
| // """ | |||
| // if is_sequence(shallow_tree): | |||
| // if not is_sequence(input_tree): | |||
| // raise TypeError( | |||
| // "If shallow structure is a sequence, input must also be a sequence. " | |||
| // "Input has type: %s." % type(input_tree)) | |||
| // if check_types and not isinstance(input_tree, type(shallow_tree)): | |||
| //# Duck-typing means that nest should be fine with two different | |||
| //# namedtuples with identical name and fields. | |||
| // shallow_is_namedtuple = _is_namedtuple(shallow_tree, False) | |||
| // input_is_namedtuple = _is_namedtuple(input_tree, False) | |||
| // if shallow_is_namedtuple and input_is_namedtuple: | |||
| // if not _same_namedtuples(shallow_tree, input_tree): | |||
| // raise TypeError( | |||
| // "The two namedtuples don't have the same sequence type. Input " | |||
| // "structure has type %s, while shallow structure has type %s." | |||
| // % (type(input_tree), type(shallow_tree))) | |||
| // elif not (isinstance(shallow_tree, _collections.Mapping) | |||
| // and isinstance(input_tree, _collections.Mapping)): | |||
| // raise TypeError( | |||
| // "The two structures don't have the same sequence type. Input " | |||
| // "structure has type %s, while shallow structure has type %s." | |||
| // % (type(input_tree), type(shallow_tree))) | |||
| // if len(input_tree) != len(shallow_tree): | |||
| // raise ValueError( | |||
| // "The two structures don't have the same sequence length. Input " | |||
| // "structure has length %s, while shallow structure has length %s." | |||
| // % (len(input_tree), len(shallow_tree))) | |||
| // if check_types and isinstance(shallow_tree, (dict, _collections.Mapping)): | |||
| // if set(input_tree) != set(shallow_tree): | |||
| // raise ValueError( | |||
| // "The two structures don't have the same keys. Input " | |||
| // "structure has keys %s, while shallow structure has keys %s." % | |||
| // (list(_six.iterkeys(input_tree)), | |||
| // list(_six.iterkeys(shallow_tree)))) | |||
| // input_tree = list(sorted(_six.iteritems(input_tree))) | |||
| // shallow_tree = list(sorted(_six.iteritems(shallow_tree))) | |||
| // for shallow_branch, input_branch in zip(shallow_tree, input_tree): | |||
| // assert_shallow_structure(shallow_branch, input_branch, | |||
| // check_types=check_types) | |||
| //def flatten_up_to(shallow_tree, input_tree): | |||
| // """Flattens `input_tree` up to `shallow_tree`. | |||
| // Any further depth in structure in `input_tree` is retained as elements in the | |||
| // partially flatten output. | |||
| // If `shallow_tree` and `input_tree` are not sequences, this returns a | |||
| // single-element list: `[input_tree]`. | |||
| // Use Case: | |||
| // Sometimes we may wish to partially flatten a nested sequence, retaining some | |||
| // of the nested structure. We achieve this by specifying a shallow structure, | |||
| // `shallow_tree`, we wish to flatten up to. | |||
| // The input, `input_tree`, can be thought of as having the same structure as | |||
| // `shallow_tree`, but with leaf nodes that are themselves tree structures. | |||
| // Examples: | |||
| // ```python | |||
| // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||
| // shallow_tree = [[True, True], [False, True]] | |||
| // flattened_input_tree = flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree) | |||
| //# Output is: | |||
| //# [[2, 2], [3, 3], [4, 9], [5, 5]] | |||
| //# [True, True, False, True] | |||
| // ``` | |||
| // ```python | |||
| // input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]] | |||
| // shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]] | |||
| // input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree) | |||
| // input_tree_flattened = flatten(input_tree) | |||
| //# Output is: | |||
| //# [('a', 1), ('b', 2), ('c', 3), ('d', 4)] | |||
| //# ['a', 1, 'b', 2, 'c', 3, 'd', 4] | |||
| // ``` | |||
| // Non-Sequence Edge Cases: | |||
| // ```python | |||
| // flatten_up_to(0, 0) # Output: [0] | |||
| // flatten_up_to(0, [0, 1, 2]) # Output: [[0, 1, 2]] | |||
| // flatten_up_to([0, 1, 2], 0) # Output: TypeError | |||
| // flatten_up_to([0, 1, 2], [0, 1, 2]) # Output: [0, 1, 2] | |||
| // ``` | |||
| // Args: | |||
| // shallow_tree: a possibly pruned structure of input_tree. | |||
| // input_tree: an arbitrarily nested structure or a scalar object. | |||
| // Note, numpy arrays are considered scalars. | |||
| // Returns: | |||
| // A Python list, the partially flattened version of `input_tree` according to | |||
| // the structure of `shallow_tree`. | |||
| // Raises: | |||
| // TypeError: If `shallow_tree` is a sequence but `input_tree` is not. | |||
| // TypeError: If the sequence types of `shallow_tree` are different from | |||
| // `input_tree`. | |||
| // ValueError: If the sequence lengths of `shallow_tree` are different from | |||
| // `input_tree`. | |||
| // """ | |||
| // assert_shallow_structure(shallow_tree, input_tree) | |||
| // return list(_yield_flat_up_to(shallow_tree, input_tree)) | |||
| //def map_structure_up_to(shallow_tree, func, *inputs): | |||
| // """Applies a function or op to a number of partially flattened inputs. | |||
| // The `inputs` are flattened up to `shallow_tree` before being mapped. | |||
| // Use Case: | |||
| // Sometimes we wish to apply a function to a partially flattened | |||
| // sequence (for example when the function itself takes sequence inputs). We | |||
| // achieve this by specifying a shallow structure, `shallow_tree` we wish to | |||
| // flatten up to. | |||
| // The `inputs`, can be thought of as having the same structure as | |||
| // `shallow_tree`, but with leaf nodes that are themselves tree structures. | |||
| // This function therefore will return something with the same base structure as | |||
| // `shallow_tree`. | |||
| // Examples: | |||
| // ```python | |||
| // ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||
| // op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||
| // inp_val = ab_tuple(a=2, b=3) | |||
| // inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) | |||
| // out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul, | |||
| // inp_val, inp_ops) | |||
| //# Output is: ab_tuple(a=6, b=15) | |||
| // ``` | |||
| // ```python | |||
| // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||
| // name_list = ['evens', ['odds', 'primes']] | |||
| // out = map_structure_up_to( | |||
| // name_list, | |||
| // lambda name, sec: "first_{}_{}".format(len(sec), name), | |||
| // name_list, data_list) | |||
| //# Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']] | |||
| // ``` | |||
| // Args: | |||
| // shallow_tree: a shallow tree, common to all the inputs. | |||
| // func: callable which will be applied to each input individually. | |||
| // *inputs: arbitrarily nested combination of objects that are compatible with | |||
| // shallow_tree. The function `func` is applied to corresponding | |||
| // partially flattened elements of each input, so the function must support | |||
| // arity of `len(inputs)`. | |||
| // Raises: | |||
| // TypeError: If `shallow_tree` is a sequence but `input_tree` is not. | |||
| // TypeError: If the sequence types of `shallow_tree` are different from | |||
| // `input_tree`. | |||
| // ValueError: If the sequence lengths of `shallow_tree` are different from | |||
| // `input_tree`. | |||
| // Returns: | |||
| // result of repeatedly applying `func`, with same structure as | |||
| // `shallow_tree`. | |||
| // """ | |||
| // if not inputs: | |||
| // raise ValueError("Cannot map over no sequences") | |||
| // for input_tree in inputs: | |||
| // assert_shallow_structure(shallow_tree, input_tree) | |||
| //# Flatten each input separately, apply the function to corresponding elements, | |||
| //# then repack based on the structure of the first input. | |||
| // all_flattened_up_to = [flatten_up_to(shallow_tree, input_tree) | |||
| // for input_tree in inputs] | |||
| // results = [func(*tensors) for tensors in zip(*all_flattened_up_to)] | |||
| // return pack_sequence_as(structure=shallow_tree, flat_sequence=results) | |||
| //def get_traverse_shallow_structure(traverse_fn, structure): | |||
| // """Generates a shallow structure from a `traverse_fn` and `structure`. | |||
| // `traverse_fn` must accept any possible subtree of `structure` and return | |||
| // a depth=1 structure containing `True` or `False` values, describing which | |||
| // of the top-level subtrees may be traversed. It may also | |||
| // return scalar `True` or `False` "traversal is OK / not OK for all subtrees." | |||
| // Examples are available in the unit tests (nest_test.py). | |||
| // Args: | |||
| // traverse_fn: Function taking a substructure and returning either a scalar | |||
| // `bool` (whether to traverse that substructure or not) or a depth=1 | |||
| // shallow structure of the same type, describing which parts of the | |||
| // substructure to traverse. | |||
| // structure: The structure to traverse. | |||
| // Returns: | |||
| // A shallow structure containing python bools, which can be passed to | |||
| // `map_structure_up_to` and `flatten_up_to`. | |||
| // Raises: | |||
| // TypeError: if `traverse_fn` returns a sequence for a non-sequence input, | |||
| // or a structure with depth higher than 1 for a sequence input, | |||
| // or if any leaf values in the returned structure or scalar are not type | |||
| // `bool`. | |||
| // """ | |||
| // to_traverse = traverse_fn(structure) | |||
| // if not is_sequence(structure): | |||
| // if not isinstance(to_traverse, bool): | |||
| // raise TypeError("traverse_fn returned structure: %s for non-structure: %s" | |||
| // % (to_traverse, structure)) | |||
| // return to_traverse | |||
| // level_traverse = [] | |||
| // if isinstance(to_traverse, bool): | |||
| // if not to_traverse: | |||
| //# Do not traverse this substructure at all. Exit early. | |||
| // return False | |||
| // else: | |||
| //# Traverse the entire substructure. | |||
| // for branch in _yield_value(structure): | |||
| // level_traverse.append( | |||
| // get_traverse_shallow_structure(traverse_fn, branch)) | |||
| // elif not is_sequence(to_traverse): | |||
| // raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s" | |||
| // % (to_traverse, structure)) | |||
| // else: | |||
| //# Traverse some subset of this substructure. | |||
| // assert_shallow_structure(to_traverse, structure) | |||
| // for t, branch in zip(_yield_value(to_traverse), _yield_value(structure)): | |||
| // if not isinstance(t, bool): | |||
| // raise TypeError( | |||
| // "traverse_fn didn't return a depth=1 structure of bools. saw: %s " | |||
| // " for structure: %s" % (to_traverse, structure)) | |||
| // if t: | |||
| // level_traverse.append( | |||
| // get_traverse_shallow_structure(traverse_fn, branch)) | |||
| // else: | |||
| // level_traverse.append(False) | |||
| // return _sequence_like(structure, level_traverse) | |||
| //def yield_flat_paths(nest): | |||
| // """Yields paths for some nested structure. | |||
| // Paths are lists of objects which can be str-converted, which may include | |||
| // integers or other types which are used as indices in a dict. | |||
| // The flat list will be in the corresponding order as if you called | |||
| // `snt.nest.flatten` on the structure. This is handy for naming Tensors such | |||
| // the TF scope structure matches the tuple structure. | |||
| // E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))` | |||
| // ```shell | |||
| // >>> nest.flatten(value) | |||
| // [3, 23, 42] | |||
| // >>> list(nest.yield_flat_paths(value)) | |||
| // [('a',), ('b', 'c'), ('b', 'd')] | |||
| // ``` | |||
| // ```shell | |||
| // >>> list(nest.yield_flat_paths({'a': [3]})) | |||
| // [('a', 0)] | |||
| // >>> list(nest.yield_flat_paths({'a': 3})) | |||
| // [('a',)] | |||
| // ``` | |||
| // Args: | |||
| // nest: the value to produce a flattened paths list for. | |||
| // Yields: | |||
| // Tuples containing index or key values which form the path to a specific | |||
| // leaf value in the nested structure. | |||
| // """ | |||
| //# The _maybe_add_final_path_element function is used below in order to avoid | |||
| //# adding trailing slashes when the sub-element recursed into is a leaf. | |||
| // if isinstance(nest, (dict, _collections.Mapping)): | |||
| // for key in _sorted(nest): | |||
| // value = nest[key] | |||
| // for sub_path in yield_flat_paths(value): | |||
| // yield (key,) + sub_path | |||
| // elif _is_namedtuple(nest): | |||
| // for key in nest._fields: | |||
| // value = getattr(nest, key) | |||
| // for sub_path in yield_flat_paths(value): | |||
| // yield (key,) + sub_path | |||
| // elif isinstance(nest, _six.string_types): | |||
| // yield () | |||
| // elif isinstance(nest, _collections.Sequence): | |||
| // for idx, value in enumerate(nest): | |||
| // for sub_path in yield_flat_paths(value): | |||
| // yield (idx,) + sub_path | |||
| // else: | |||
| // yield () | |||
| //def flatten_with_joined_string_paths(structure, separator="/"): | |||
| // """Returns a list of (string path, data element) tuples. | |||
| // The order of tuples produced matches that of `nest.flatten`. This allows you | |||
| // to flatten a nested structure while keeping information about where in the | |||
| // structure each data element was located. See `nest.yield_flat_paths` | |||
| // for more information. | |||
| // Args: | |||
| // structure: the nested structure to flatten. | |||
| // separator: string to separate levels of hierarchy in the results, defaults | |||
| // to '/'. | |||
| // Returns: | |||
| // A list of (string, data element) tuples. | |||
| // """ | |||
| // flat_paths = yield_flat_paths(structure) | |||
| // def stringify_and_join(path_elements): | |||
| // return separator.join(str(path_element) for path_element in path_elements) | |||
| // flat_string_paths = [stringify_and_join(path) for path in flat_paths] | |||
| // return list(zip(flat_string_paths, flatten(structure))) | |||
| } | |||
| } | |||
| @@ -5,6 +5,7 @@ using System.Linq; | |||
| using System.Text; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| using Tensorflow.Util; | |||
| namespace TensorFlowNET.UnitTest | |||
| { | |||
| @@ -13,6 +14,15 @@ namespace TensorFlowNET.UnitTest | |||
| /// </summary> | |||
| public class PythonTest : Python | |||
| { | |||
| #region python compatibility layer | |||
| protected PythonTest self { get => this; } | |||
| protected object None { | |||
| get { return null; } | |||
| } | |||
| #endregion | |||
| #region pytest assertions | |||
| public void assertItemsEqual(ICollection given, ICollection expected) | |||
| { | |||
| Assert.IsNotNull(expected); | |||
| @@ -20,20 +30,62 @@ namespace TensorFlowNET.UnitTest | |||
| var e = expected.OfType<object>().ToArray(); | |||
| var g = given.OfType<object>().ToArray(); | |||
| Assert.AreEqual(e.Length, g.Length, $"The collections differ in length expected {e.Length} but got {g.Length}"); | |||
| for(int i=0; i<e.Length; i++) | |||
| for (int i = 0; i < e.Length; i++) | |||
| Assert.AreEqual(e[i], g[i], $"Items differ at index {i}, expected {e[i]} but got {g[i]}"); | |||
| } | |||
| public void assertEqual(object given, object expected) | |||
| { | |||
| if (given is ICollection && expected is ICollection) | |||
| { | |||
| assertItemsEqual(given as ICollection, expected as ICollection); | |||
| return; | |||
| } | |||
| Assert.AreEqual(expected, given); | |||
| } | |||
| public void assertEquals(object given, object expected) | |||
| { | |||
| assertEqual(given, expected); | |||
| } | |||
| public void assertIsNotNone(object given) | |||
| { | |||
| Assert.IsNotNull(given); | |||
| } | |||
| protected PythonTest self { get => this; } | |||
| #endregion | |||
| #region tensor evaluation | |||
| protected object _eval_helper(Tensor[] tensors) | |||
| { | |||
| if (tensors == null) | |||
| return null; | |||
| //return nest.map_structure(self._eval_tensor, tensors); | |||
| return null; | |||
| } | |||
| //def evaluate(self, tensors) : | |||
| // """Evaluates tensors and returns numpy values. | |||
| // Args: | |||
| // tensors: A Tensor or a nested list/tuple of Tensors. | |||
| // Returns: | |||
| // tensors numpy values. | |||
| // """ | |||
| // if context.executing_eagerly(): | |||
| // return self._eval_helper(tensors) | |||
| // else: | |||
| // sess = ops.get_default_session() | |||
| // if sess is None: | |||
| // with self.test_session() as sess: | |||
| // return sess.run(tensors) | |||
| // else: | |||
| // return sess.run(tensors) | |||
| #endregion | |||
| } | |||
| } | |||
| @@ -0,0 +1,107 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| { | |||
| /// <summary> | |||
| /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py | |||
| /// </summary> | |||
| [TestClass] | |||
| public class CondTestCases : PythonTest | |||
| { | |||
| [Ignore("Todo")] | |||
| [TestMethod] | |||
| public void testCondTrue() | |||
| { | |||
| //var x = constant_op.constant(2); | |||
| //var y = constant_op.constant(5); | |||
| // var z = control_flow_ops.cond(math_ops.less(x,y), ()=> math_ops.multiply(x, 17), ()=> math_ops.add(y, 23)) | |||
| //self.assertEquals(self.evaluate(z), 34); | |||
| } | |||
| [Ignore("Todo")] | |||
| [TestMethod] | |||
| public void testCondFalse() | |||
| { | |||
| // def testCondFalse(self): | |||
| // x = constant_op.constant(2) | |||
| // y = constant_op.constant(1) | |||
| // z = control_flow_ops.cond( | |||
| // math_ops.less( | |||
| // x, | |||
| // y), lambda: math_ops.multiply(x, 17), lambda: math_ops.add(y, 23)) | |||
| // self.assertEquals(self.evaluate(z), 24) | |||
| } | |||
| [Ignore("Todo")] | |||
| [TestMethod] | |||
| public void testCondTrueLegacy() | |||
| { | |||
| // def testCondTrueLegacy(self): | |||
| // x = constant_op.constant(2) | |||
| // y = constant_op.constant(5) | |||
| // z = control_flow_ops.cond( | |||
| // math_ops.less(x, y), | |||
| // fn1=lambda: math_ops.multiply(x, 17), | |||
| // fn2=lambda: math_ops.add(y, 23)) | |||
| // self.assertEquals(self.evaluate(z), 34) | |||
| } | |||
| [Ignore("Todo")] | |||
| [TestMethod] | |||
| public void testCondFalseLegacy() | |||
| { | |||
| // def testCondFalseLegacy(self): | |||
| // x = constant_op.constant(2) | |||
| // y = constant_op.constant(1) | |||
| // z = control_flow_ops.cond( | |||
| // math_ops.less(x, y), | |||
| // fn1=lambda: math_ops.multiply(x, 17), | |||
| // fn2=lambda: math_ops.add(y, 23)) | |||
| // self.assertEquals(self.evaluate(z), 24) | |||
| } | |||
| [Ignore("Todo")] | |||
| [TestMethod] | |||
| public void testCondMissingArg1() | |||
| { | |||
| // def testCondMissingArg1(self): | |||
| // x = constant_op.constant(1) | |||
| // with self.assertRaises(TypeError): | |||
| // control_flow_ops.cond(True, false_fn=lambda: x) | |||
| } | |||
| [Ignore("Todo")] | |||
| [TestMethod] | |||
| public void testCondMissingArg2() | |||
| { | |||
| // def testCondMissingArg2(self): | |||
| // x = constant_op.constant(1) | |||
| // with self.assertRaises(TypeError): | |||
| // control_flow_ops.cond(True, lambda: x) | |||
| } | |||
| [Ignore("Todo")] | |||
| [TestMethod] | |||
| public void testCondDuplicateArg1() | |||
| { | |||
| // def testCondDuplicateArg1(self): | |||
| // x = constant_op.constant(1) | |||
| // with self.assertRaises(TypeError): | |||
| // control_flow_ops.cond(True, lambda: x, lambda: x, fn1=lambda: x) | |||
| } | |||
| [Ignore("Todo")] | |||
| [TestMethod] | |||
| public void testCondDuplicateArg2() | |||
| { | |||
| // def testCondDuplicateArg2(self): | |||
| // x = constant_op.constant(1) | |||
| // with self.assertRaises(TypeError): | |||
| // control_flow_ops.cond(True, lambda: x, lambda: x, fn2=lambda: x) | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,23 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| { | |||
| /// <summary> | |||
| /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py | |||
| /// </summary> | |||
| [TestClass] | |||
| public class ShapeTestCase : PythonTest | |||
| { | |||
| [TestMethod] | |||
| public void testShape() | |||
| { | |||
| var tensor = constant_op.constant(new[]{1.0, 2.0}); | |||
| self.assertEquals(new int[] {2}, tensor.shape); | |||
| self.assertEquals(new int[] {2}, | |||
| control_flow_ops.with_dependencies(new[] {constant_op.constant(1.0).op}, tensor).shape); | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,162 @@ | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Tensorflow; | |||
| namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| { | |||
| /// <summary> | |||
| /// excerpt of tensorflow/python/framework/ops/control_flow_ops_test.py | |||
| /// </summary> | |||
| [TestClass] | |||
| public class SwitchTestCase : PythonTest | |||
| { | |||
| [Ignore("TODO")] | |||
| [TestMethod] | |||
| public void testResourceReadInLoop() | |||
| { | |||
| //def testResourceReadInLoop(self): | |||
| // embedding_matrix = variable_scope.get_variable( | |||
| // "embedding_matrix", initializer=[[2.0], [3.0]], use_resource=True) | |||
| // | |||
| // def cond(it, _): | |||
| // return it < 5 | |||
| // | |||
| // def body(it, cost): | |||
| // embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) | |||
| // cost += math_ops.reduce_sum(embedding) | |||
| // return it + 1, cost | |||
| // | |||
| // _, cost = control_flow_ops.while_loop( | |||
| // cond, body, [constant_op.constant(0), | |||
| // constant_op.constant(0.0)]) | |||
| // with self.cached_session(): | |||
| // self.evaluate(variables.global_variables_initializer()) | |||
| // self.assertAllEqual(10.0, self.evaluate(cost)) | |||
| } | |||
| [Ignore("TODO")] | |||
| [TestMethod] | |||
| public void testIndexedSlicesGradientInCondInWhileLoop() | |||
| { | |||
| doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: false); | |||
| } | |||
| [Ignore("TODO")] | |||
| [TestMethod] | |||
| public void testIndexedSlicesGradientInCondInWhileLoopResource() | |||
| { | |||
| doTestIndexedSlicesGradientInCondInWhileLoop(use_resource: true); | |||
| } | |||
| private void doTestIndexedSlicesGradientInCondInWhileLoop(bool use_resource= false) | |||
| { | |||
| //def doTestIndexedSlicesGradientInCondInWhileLoop(self, use_resource=False): | |||
| // embedding_matrix = variable_scope.get_variable( | |||
| // "embedding_matrix", [5, 5], | |||
| // initializer=init_ops.random_normal_initializer(), | |||
| // use_resource=use_resource) | |||
| // def cond(it, _): | |||
| // return it < 5 | |||
| // def body(it, cost): | |||
| // embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) | |||
| // cost = control_flow_ops.cond( | |||
| // math_ops.equal(it, 3), lambda: math_ops.square(cost), | |||
| // (lambda: cost + math_ops.reduce_sum(embedding))) | |||
| // return it + 1, cost | |||
| // _, cost = control_flow_ops.while_loop( | |||
| // cond, body, [constant_op.constant(0), | |||
| // constant_op.constant(0.0)]) | |||
| // dynamic_grads = gradients_impl.gradients(cost, [embedding_matrix])[0] | |||
| // dynamic_grads = math_ops.segment_sum(dynamic_grads.values, | |||
| // dynamic_grads.indices) | |||
| // embedding = embedding_ops.embedding_lookup(embedding_matrix, [0]) | |||
| // static = math_ops.square( | |||
| // math_ops.reduce_sum(embedding) + math_ops.reduce_sum(embedding) + | |||
| // math_ops.reduce_sum(embedding)) + math_ops.reduce_sum(embedding) | |||
| // static_grads = gradients_impl.gradients(static, [embedding_matrix])[0] | |||
| // static_grads = math_ops.segment_sum(static_grads.values, | |||
| // static_grads.indices) | |||
| // with self.cached_session(): | |||
| // self.evaluate(variables.global_variables_initializer()) | |||
| // self.assertAllEqual(*self.evaluate([static_grads, dynamic_grads])) | |||
| } | |||
| [Ignore("TODO")] | |||
| [TestMethod] | |||
| public void testIndexedSlicesWithShapeGradientInWhileLoop() | |||
| { | |||
| //@test_util.run_v1_only("b/120545219") | |||
| //def testIndexedSlicesWithShapeGradientInWhileLoop(self): | |||
| // for dtype in [dtypes.float32, dtypes.float64]: | |||
| // with self.cached_session() as sess: | |||
| // num_steps = 9 | |||
| // inputs = array_ops.placeholder(dtype=dtype, shape=[num_steps]) | |||
| // initial_outputs = tensor_array_ops.TensorArray( | |||
| // dtype=dtype, size=num_steps) | |||
| // initial_i = constant_op.constant(0, dtype=dtypes.int32) | |||
| // def cond(i, _): | |||
| // return i < num_steps # pylint: disable=cell-var-from-loop | |||
| // def body(i, outputs): | |||
| // x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop | |||
| // outputs = outputs.write(i, x) | |||
| // return i + 1, outputs | |||
| // _, outputs = control_flow_ops.while_loop(cond, body, | |||
| // [initial_i, initial_outputs]) | |||
| // outputs = math_ops.reduce_sum(outputs.stack()) | |||
| // r = gradients_impl.gradients([outputs], [inputs])[0] | |||
| // grad_wr_inputs = ops.convert_to_tensor(r) | |||
| // o, grad = sess.run([outputs, grad_wr_inputs], | |||
| // feed_dict={inputs: [4, 6, 0, 7, 0, 0, 1, 2, 0]}) | |||
| // self.assertEquals(o, 20) | |||
| // self.assertAllEqual(grad, [1] * num_steps) | |||
| } | |||
| [Ignore("TODO")] | |||
| [TestMethod] | |||
| public void testIndexedSlicesWithDynamicShapeGradientInWhileLoop() | |||
| { | |||
| //@test_util.run_v1_only("b/120545219") | |||
| //def testIndexedSlicesWithDynamicShapeGradientInWhileLoop(self): | |||
| // for dtype in [dtypes.float32, dtypes.float64]: | |||
| // with self.cached_session() as sess: | |||
| // inputs = array_ops.placeholder(dtype=dtype) | |||
| // initial_outputs = tensor_array_ops.TensorArray( | |||
| // dtype=dtype, dynamic_size=True, size=1) | |||
| // initial_i = constant_op.constant(0, dtype=dtypes.int32) | |||
| // def cond(i, _): | |||
| // return i < array_ops.size(inputs) # pylint: disable=cell-var-from-loop | |||
| // def body(i, outputs): | |||
| // x = array_ops.gather(inputs, i) # pylint: disable=cell-var-from-loop | |||
| // outputs = outputs.write(i, x) | |||
| // return i + 1, outputs | |||
| // _, outputs = control_flow_ops.while_loop(cond, body, | |||
| // [initial_i, initial_outputs]) | |||
| // outputs = math_ops.reduce_sum(outputs.stack()) | |||
| // r = gradients_impl.gradients([outputs], [inputs])[0] | |||
| // grad_wr_inputs = ops.convert_to_tensor(r) | |||
| // o, grad = sess.run([outputs, grad_wr_inputs], | |||
| // feed_dict={inputs: [1, 3, 2]}) | |||
| // self.assertEquals(o, 6) | |||
| // self.assertAllEqual(grad, [1] * 3) | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,852 @@ | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | |||
| using Newtonsoft.Json.Linq; | |||
| using Tensorflow; | |||
| using Tensorflow.Util; | |||
| namespace TensorFlowNET.UnitTest.control_flow_ops_test | |||
| { | |||
| /// <summary> | |||
| /// excerpt of tensorflow/python/framework/util/nest_test.py | |||
| /// </summary> | |||
| [TestClass] | |||
| public class NestTest : PythonTest | |||
| { | |||
| public class PointXY | |||
| { | |||
| public double x; | |||
| public double y; | |||
| } | |||
| // if attr: | |||
| // class BadAttr(object): | |||
| // """Class that has a non-iterable __attrs_attrs__.""" | |||
| // __attrs_attrs__ = None | |||
| // @attr.s | |||
| // class SampleAttr(object): | |||
| // field1 = attr.ib() | |||
| // field2 = attr.ib() | |||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| // def testAttrsFlattenAndPack(self) : | |||
| // if attr is None: | |||
| // self.skipTest("attr module is unavailable.") | |||
| // field_values = [1, 2] | |||
| // sample_attr = NestTest.SampleAttr(* field_values) | |||
| // self.assertFalse(nest._is_attrs(field_values)) | |||
| // self.assertTrue(nest._is_attrs(sample_attr)) | |||
| // flat = nest.flatten(sample_attr) | |||
| // self.assertEqual(field_values, flat) | |||
| // restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) | |||
| // self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) | |||
| // self.assertEqual(restructured_from_flat, sample_attr) | |||
| //# Check that flatten fails if attributes are not iterable | |||
| // with self.assertRaisesRegexp(TypeError, "object is not iterable"): | |||
| // flat = nest.flatten(NestTest.BadAttr()) | |||
| [TestMethod] | |||
| public void testFlattenAndPack() | |||
| { | |||
| object structure = new object[] {new object[] {3, 4}, 5, new object[] {6, 7, new object[] {9, 10}, 8}}; | |||
| var flat = new List<object> {"a", "b", "c", "d", "e", "f", "g", "h"}; | |||
| self.assertEqual(nest.flatten(structure), new[] {3, 4, 5, 6, 7, 9, 10, 8}); | |||
| self.assertEqual(JArray.FromObject(nest.pack_sequence_as(structure, flat)).ToString(), | |||
| JArray.FromObject(new object[] {new object[] {"a", "b"}, "c", new object[] {"d", "e", new object[] {"f", "g"}, "h"}}).ToString()); | |||
| structure = new object[] { new Hashtable {["x"] = 4, ["y"] = 2}, new object[] { new object[] { new Hashtable { ["x"] = 1,["y"] = 0}, }, }}; | |||
| flat = new List<object> { 4, 2, 1, 0}; | |||
| self.assertEqual(nest.flatten(structure), flat); | |||
| // restructured_from_flat = nest.pack_sequence_as(structure, flat) | |||
| // self.assertEqual(restructured_from_flat, structure) | |||
| // self.assertEqual(restructured_from_flat[0].x, 4) | |||
| // self.assertEqual(restructured_from_flat[0].y, 2) | |||
| // self.assertEqual(restructured_from_flat[1][0][0].x, 1) | |||
| // self.assertEqual(restructured_from_flat[1][0][0].y, 0) | |||
| // self.assertEqual([5], nest.flatten(5)) | |||
| // self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) | |||
| // self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) | |||
| // self.assertEqual( | |||
| // np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) | |||
| // with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): | |||
| // nest.pack_sequence_as("scalar", [4, 5]) | |||
| // with self.assertRaisesRegexp(TypeError, "flat_sequence"): | |||
| // nest.pack_sequence_as([4, 5], "bad_sequence") | |||
| // with self.assertRaises(ValueError): | |||
| // nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) | |||
| } | |||
| // @parameterized.parameters({"mapping_type": collections.OrderedDict | |||
| // }, | |||
| // {"mapping_type": _CustomMapping | |||
| //}) | |||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| // def testFlattenDictOrder(self, mapping_type) : | |||
| // """`flatten` orders dicts by key, including OrderedDicts.""" | |||
| // ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) | |||
| // plain = {"d": 3, "b": 1, "a": 0, "c": 2} | |||
| // ordered_flat = nest.flatten(ordered) | |||
| // plain_flat = nest.flatten(plain) | |||
| // self.assertEqual([0, 1, 2, 3], ordered_flat) | |||
| // self.assertEqual([0, 1, 2, 3], plain_flat) | |||
| // @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||
| // {"mapping_type": _CustomMapping}) | |||
| // def testPackDictOrder(self, mapping_type): | |||
| // """Packing orders dicts by key, including OrderedDicts.""" | |||
| // custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) | |||
| // plain = {"d": 0, "b": 0, "a": 0, "c": 0} | |||
| // seq = [0, 1, 2, 3] | |||
| //custom_reconstruction = nest.pack_sequence_as(custom, seq) | |||
| //plain_reconstruction = nest.pack_sequence_as(plain, seq) | |||
| // self.assertIsInstance(custom_reconstruction, mapping_type) | |||
| // self.assertIsInstance(plain_reconstruction, dict) | |||
| // self.assertEqual( | |||
| // mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), | |||
| // custom_reconstruction) | |||
| // self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) | |||
| // Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name | |||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| // def testFlattenAndPack_withDicts(self) : | |||
| // # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. | |||
| // mess = [ | |||
| // "z", | |||
| // NestTest.Abc(3, 4), { | |||
| // "d": _CustomMapping({ | |||
| // 41: 4 | |||
| // }), | |||
| // "c": [ | |||
| // 1, | |||
| // collections.OrderedDict([ | |||
| // ("b", 3), | |||
| // ("a", 2), | |||
| // ]), | |||
| // ], | |||
| // "b": 5 | |||
| // }, 17 | |||
| // ] | |||
| // flattened = nest.flatten(mess) | |||
| // self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) | |||
| // structure_of_mess = [ | |||
| // 14, | |||
| // NestTest.Abc("a", True), | |||
| // { | |||
| // "d": _CustomMapping({ | |||
| // 41: 42 | |||
| // }), | |||
| // "c": [ | |||
| // 0, | |||
| // collections.OrderedDict([ | |||
| // ("b", 9), | |||
| // ("a", 8), | |||
| // ]), | |||
| // ], | |||
| // "b": 3 | |||
| // }, | |||
| // "hi everybody", | |||
| // ] | |||
| // unflattened = nest.pack_sequence_as(structure_of_mess, flattened) | |||
| // self.assertEqual(unflattened, mess) | |||
| // # Check also that the OrderedDict was created, with the correct key order. | |||
| //unflattened_ordered_dict = unflattened[2]["c"][1] | |||
| // self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) | |||
| // self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) | |||
| // unflattened_custom_mapping = unflattened[2]["d"] | |||
| // self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) | |||
| // self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) | |||
| // def testFlatten_numpyIsNotFlattened(self): | |||
| // structure = np.array([1, 2, 3]) | |||
| // flattened = nest.flatten(structure) | |||
| // self.assertEqual(len(flattened), 1) | |||
| // def testFlatten_stringIsNotFlattened(self): | |||
| // structure = "lots of letters" | |||
| // flattened = nest.flatten(structure) | |||
| // self.assertEqual(len(flattened), 1) | |||
| // unflattened = nest.pack_sequence_as("goodbye", flattened) | |||
| // self.assertEqual(structure, unflattened) | |||
| // def testPackSequenceAs_notIterableError(self) : | |||
| // with self.assertRaisesRegexp(TypeError, | |||
| // "flat_sequence must be a sequence"): | |||
| // nest.pack_sequence_as("hi", "bye") | |||
| // def testPackSequenceAs_wrongLengthsError(self): | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // "Structure had 2 elements, but flat_sequence had 3 elements."): | |||
| // nest.pack_sequence_as(["hello", "world"], | |||
| // ["and", "goodbye", "again"]) | |||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| // def testIsSequence(self): | |||
| // self.assertFalse(nest.is_sequence("1234")) | |||
| // self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) | |||
| // self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) | |||
| // self.assertTrue(nest.is_sequence([])) | |||
| // self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) | |||
| // self.assertFalse(nest.is_sequence(set([1, 2]))) | |||
| // ones = array_ops.ones([2, 3]) | |||
| // self.assertFalse(nest.is_sequence(ones)) | |||
| // self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) | |||
| // self.assertFalse(nest.is_sequence(np.ones((4, 5)))) | |||
| // @parameterized.parameters({"mapping_type": _CustomMapping}, | |||
| // {"mapping_type": dict}) | |||
| // def testFlattenDictItems(self, mapping_type): | |||
| // dictionary = mapping_type({ (4, 5, (6, 8)): ("a", "b", ("c", "d"))}) | |||
| // flat = {4: "a", 5: "b", 6: "c", 8: "d"} | |||
| // self.assertEqual(nest.flatten_dict_items(dictionary), flat) | |||
| // with self.assertRaises(TypeError): | |||
| // nest.flatten_dict_items(4) | |||
| // bad_dictionary = mapping_type({ (4, 5, (4, 8)): ("a", "b", ("c", "d"))}) | |||
| // with self.assertRaisesRegexp(ValueError, "not unique"): | |||
| // nest.flatten_dict_items(bad_dictionary) | |||
| // another_bad_dictionary = mapping_type({ | |||
| // (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) | |||
| // }) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): | |||
| // nest.flatten_dict_items(another_bad_dictionary) | |||
| //# pylint does not correctly recognize these as class names and | |||
| //# suggests to use variable style under_score naming. | |||
| //# pylint: disable=invalid-name | |||
| // Named0ab = collections.namedtuple("named_0", ("a", "b")) | |||
| // Named1ab = collections.namedtuple("named_1", ("a", "b")) | |||
| // SameNameab = collections.namedtuple("same_name", ("a", "b")) | |||
| // SameNameab2 = collections.namedtuple("same_name", ("a", "b")) | |||
| // SameNamexy = collections.namedtuple("same_name", ("x", "y")) | |||
| // SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) | |||
| // SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) | |||
| // NotSameName = collections.namedtuple("not_same_name", ("a", "b")) | |||
| // # pylint: enable=invalid-name | |||
| // class SameNamedType1(SameNameab): | |||
| // pass | |||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| // def testAssertSameStructure(self): | |||
| // structure1 = (((1, 2), 3), 4, (5, 6)) | |||
| // structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||
| // structure_different_num_elements = ("spam", "eggs") | |||
| // structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) | |||
| // nest.assert_same_structure(structure1, structure2) | |||
| // nest.assert_same_structure("abc", 1.0) | |||
| // nest.assert_same_structure("abc", np.array([0, 1])) | |||
| // nest.assert_same_structure("abc", constant_op.constant([0, 1])) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // ("The two structures don't have the same nested structure\\.\n\n" | |||
| // "First structure:.*?\n\n" | |||
| // "Second structure:.*\n\n" | |||
| // "More specifically: Substructure " | |||
| // r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' | |||
| // 'substructure "type=str str=spam" is not\n' | |||
| // "Entire first structure:\n" | |||
| // r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" | |||
| // "Entire second structure:\n" | |||
| // r"\(\., \.\)")): | |||
| // nest.assert_same_structure(structure1, structure_different_num_elements) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // ("The two structures don't have the same nested structure\\.\n\n" | |||
| // "First structure:.*?\n\n" | |||
| // "Second structure:.*\n\n" | |||
| // r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||
| // r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' | |||
| // "is not")): | |||
| // nest.assert_same_structure([0, 1], np.array([0, 1])) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // ("The two structures don't have the same nested structure\\.\n\n" | |||
| // "First structure:.*?\n\n" | |||
| // "Second structure:.*\n\n" | |||
| // r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||
| // 'is a sequence, while substructure "type=int str=0" ' | |||
| // "is not")): | |||
| // nest.assert_same_structure(0, [0, 1]) | |||
| // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // ("don't have the same nested structure\\.\n\n" | |||
| // "First structure: .*?\n\nSecond structure: ")): | |||
| // nest.assert_same_structure(structure1, structure_different_nesting) | |||
| // self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), | |||
| // NestTest.Named0ab("a", "b")) | |||
| // nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||
| // NestTest.Named0ab("a", "b")) | |||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||
| // NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // ("don't have the same nested structure\\.\n\n" | |||
| // "First structure: .*?\n\nSecond structure: ")): | |||
| // nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||
| // NestTest.Named0ab([3], 4)) | |||
| // with self.assertRaisesRegexp( | |||
| // ValueError, | |||
| // ("don't have the same nested structure\\.\n\n" | |||
| // "First structure: .*?\n\nSecond structure: ")): | |||
| // nest.assert_same_structure([[3], 4], [3, [4]]) | |||
| // structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||
| // with self.assertRaisesRegexp(TypeError, | |||
| // "don't have the same sequence type"): | |||
| // nest.assert_same_structure(structure1, structure1_list) | |||
| // nest.assert_same_structure(structure1, structure2, check_types= False) | |||
| // nest.assert_same_structure(structure1, structure1_list, check_types=False) | |||
| // with self.assertRaisesRegexp(ValueError, | |||
| // "don't have the same set of keys"): | |||
| // nest.assert_same_structure({"a": 1}, {"b": 1}) | |||
| // nest.assert_same_structure(NestTest.SameNameab(0, 1), | |||
| // NestTest.SameNameab2(2, 3)) | |||
| // # This assertion is expected to pass: two namedtuples with the same | |||
| // # name and field names are considered to be identical. | |||
| // nest.assert_same_structure( | |||
| // NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), | |||
| // NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) | |||
| // expected_message = "The two structures don't have the same.*" | |||
| // with self.assertRaisesRegexp(ValueError, expected_message): | |||
| // nest.assert_same_structure( | |||
| // NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), | |||
| // NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) | |||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||
| // NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) | |||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||
| // NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) | |||
| // self.assertRaises(TypeError, nest.assert_same_structure, | |||
| // NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) | |||
| // EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name | |||
| // def testHeterogeneousComparison(self): | |||
| // nest.assert_same_structure({"a": 4}, _CustomMapping(a= 3)) | |||
| // nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) | |||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| // def testMapStructure(self) : | |||
| // structure1 = (((1, 2), 3), 4, (5, 6)) | |||
| // structure2 = (((7, 8), 9), 10, (11, 12)) | |||
| // structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) | |||
| // nest.assert_same_structure(structure1, structure1_plus1) | |||
| // self.assertAllEqual( | |||
| // [2, 3, 4, 5, 6, 7], | |||
| // nest.flatten(structure1_plus1)) | |||
| // structure1_plus_structure2 = nest.map_structure( | |||
| // lambda x, y: x + y, structure1, structure2) | |||
| // self.assertEqual( | |||
| // (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), | |||
| // structure1_plus_structure2) | |||
| // self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) | |||
| // self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) | |||
| // # Empty structures | |||
| // self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) | |||
| // self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) | |||
| // self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) | |||
| // self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, | |||
| // NestTest.EmptyNT())) | |||
| // # This is checking actual equality of types, empty list != empty tuple | |||
| // self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) | |||
| // with self.assertRaisesRegexp(TypeError, "callable"): | |||
| // nest.map_structure("bad", structure1_plus1) | |||
| // with self.assertRaisesRegexp(ValueError, "at least one structure"): | |||
| // nest.map_structure(lambda x: x) | |||
| // with self.assertRaisesRegexp(ValueError, "same number of elements"): | |||
| // nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) | |||
| // with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
| // nest.map_structure(lambda x, y: None, 3, (3,)) | |||
| // with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||
| // nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) | |||
| // with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
| // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) | |||
| // structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||
| // with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||
| // nest.map_structure(lambda x, y: None, structure1, structure1_list) | |||
| // nest.map_structure(lambda x, y: None, structure1, structure1_list, | |||
| // check_types=False) | |||
| // with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
| // nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), | |||
| // check_types=False) | |||
| // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||
| // nest.map_structure(lambda x: None, structure1, foo="a") | |||
| // with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||
| // nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") | |||
| // ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name | |||
| // @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| // def testMapStructureWithStrings(self) : | |||
| // inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) | |||
| // inp_b = NestTest.ABTuple(a=2, b=(1, 3)) | |||
| // out = nest.map_structure(lambda string, repeats: string* repeats, | |||
| // inp_a, | |||
| // inp_b) | |||
| // self.assertEqual("foofoo", out.a) | |||
| // self.assertEqual("bar", out.b[0]) | |||
| // self.assertEqual("bazbazbaz", out.b[1]) | |||
| // nt = NestTest.ABTuple(a=("something", "something_else"), | |||
| // b="yet another thing") | |||
| // rev_nt = nest.map_structure(lambda x: x[::- 1], nt) | |||
| // # Check the output is the correct structure, and all strings are reversed. | |||
| // nest.assert_same_structure(nt, rev_nt) | |||
| // self.assertEqual(nt.a[0][::- 1], rev_nt.a[0]) | |||
| // self.assertEqual(nt.a[1][::- 1], rev_nt.a[1]) | |||
| // self.assertEqual(nt.b[::- 1], rev_nt.b) | |||
| // @test_util.run_deprecated_v1 | |||
| // def testMapStructureOverPlaceholders(self) : | |||
| // inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||
| // array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||
| // inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||
| // array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||
| // output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) | |||
| // nest.assert_same_structure(output, inp_a) | |||
| // self.assertShapeEqual(np.zeros((3, 4)), output[0]) | |||
| // self.assertShapeEqual(np.zeros((3, 7)), output[1]) | |||
| // feed_dict = { | |||
| // inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), | |||
| // inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) | |||
| // } | |||
| // with self.cached_session() as sess: | |||
| // output_np = sess.run(output, feed_dict=feed_dict) | |||
| // self.assertAllClose(output_np[0], | |||
| // feed_dict[inp_a][0] + feed_dict[inp_b][0]) | |||
| // self.assertAllClose(output_np[1], | |||
| // feed_dict[inp_a][1] + feed_dict[inp_b][1]) | |||
| // def testAssertShallowStructure(self): | |||
| // inp_ab = ["a", "b"] | |||
| //inp_abc = ["a", "b", "c"] | |||
| //expected_message = ( | |||
| // "The two structures don't have the same sequence length. Input " | |||
| // "structure has length 2, while shallow structure has length 3.") | |||
| // with self.assertRaisesRegexp(ValueError, expected_message): | |||
| // nest.assert_shallow_structure(inp_abc, inp_ab) | |||
| // inp_ab1 = [(1, 1), (2, 2)] | |||
| // inp_ab2 = [[1, 1], [2, 2]] | |||
| // expected_message = ( | |||
| // "The two structures don't have the same sequence type. Input structure " | |||
| // "has type <(type|class) 'tuple'>, while shallow structure has type " | |||
| // "<(type|class) 'list'>.") | |||
| // with self.assertRaisesRegexp(TypeError, expected_message): | |||
| // nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||
| // nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types= False) | |||
| // inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} | |||
| // inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} | |||
| // expected_message = ( | |||
| // r"The two structures don't have the same keys. Input " | |||
| // r"structure has keys \['c'\], while shallow structure has " | |||
| // r"keys \['d'\].") | |||
| // with self.assertRaisesRegexp(ValueError, expected_message): | |||
| // nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||
| // inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) | |||
| // inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) | |||
| // nest.assert_shallow_structure(inp_ab, inp_ba) | |||
| // # This assertion is expected to pass: two namedtuples with the same | |||
| //# name and field names are considered to be identical. | |||
| //inp_shallow = NestTest.SameNameab(1, 2) | |||
| // inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) | |||
| // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) | |||
| // nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) | |||
| // def testFlattenUpTo(self): | |||
| // # Shallow tree ends at scalar. | |||
| // input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||
| // shallow_tree = [[True, True], [False, True]] | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) | |||
| // self.assertEqual(flattened_shallow_tree, [True, True, False, True]) | |||
| //# Shallow tree ends at string. | |||
| // input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] | |||
| // shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] | |||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| // input_tree) | |||
| // input_tree_flattened = nest.flatten(input_tree) | |||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| // [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) | |||
| // self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) | |||
| // # Make sure dicts are correctly flattened, yielding values, not keys. | |||
| //input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} | |||
| // shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} | |||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| // input_tree) | |||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| // [1, { "c": 2}, 3, (4, 5)]) | |||
| // # Namedtuples. | |||
| // ab_tuple = NestTest.ABTuple | |||
| // input_tree = ab_tuple(a =[0, 1], b = 2) | |||
| // shallow_tree = ab_tuple(a= 0, b= 1) | |||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| // input_tree) | |||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| // [[0, 1], 2]) | |||
| // # Nested dicts, OrderedDicts and namedtuples. | |||
| // input_tree = collections.OrderedDict( | |||
| // [("a", ab_tuple(a =[0, {"b": 1}], b=2)), | |||
| // ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) | |||
| // shallow_tree = input_tree | |||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| // input_tree) | |||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) | |||
| // shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) | |||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| // input_tree) | |||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| // [ab_tuple(a =[0, { "b": 1}], b=2), | |||
| // 3, | |||
| // collections.OrderedDict([("f", 4)])]) | |||
| // shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) | |||
| // input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| // input_tree) | |||
| // self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| // [ab_tuple(a =[0, {"b": 1}], b=2), | |||
| // {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) | |||
| // ## Shallow non-list edge-case. | |||
| // # Using iterable elements. | |||
| // input_tree = ["input_tree"] | |||
| //shallow_tree = "shallow_tree" | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| // input_tree = ["input_tree_0", "input_tree_1"] | |||
| //shallow_tree = "shallow_tree" | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| // # Using non-iterable elements. | |||
| //input_tree = [0] | |||
| //shallow_tree = 9 | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| // input_tree = [0, 1] | |||
| //shallow_tree = 9 | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| // ## Both non-list edge-case. | |||
| //# Using iterable elements. | |||
| //input_tree = "input_tree" | |||
| // shallow_tree = "shallow_tree" | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| // # Using non-iterable elements. | |||
| //input_tree = 0 | |||
| // shallow_tree = 0 | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_input_tree, [input_tree]) | |||
| // self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| // ## Input non-list edge-case. | |||
| //# Using iterable elements. | |||
| //input_tree = "input_tree" | |||
| // shallow_tree = ["shallow_tree"] | |||
| //expected_message = ("If shallow structure is a sequence, input must also " | |||
| // "be a sequence. Input has type: <(type|class) 'str'>.") | |||
| // with self.assertRaisesRegexp(TypeError, expected_message): | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| // input_tree = "input_tree" | |||
| // shallow_tree = ["shallow_tree_9", "shallow_tree_8"] | |||
| //with self.assertRaisesRegexp(TypeError, expected_message): | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| //# Using non-iterable elements. | |||
| // input_tree = 0 | |||
| // shallow_tree = [9] | |||
| //expected_message = ("If shallow structure is a sequence, input must also " | |||
| // "be a sequence. Input has type: <(type|class) 'int'>.") | |||
| // with self.assertRaisesRegexp(TypeError, expected_message): | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| // input_tree = 0 | |||
| // shallow_tree = [9, 8] | |||
| //with self.assertRaisesRegexp(TypeError, expected_message): | |||
| // flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| // flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| // self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| // def testMapStructureUpTo(self) : | |||
| // # Named tuples. | |||
| // ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||
| // op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||
| // inp_val = ab_tuple(a= 2, b= 3) | |||
| // inp_ops = ab_tuple(a= op_tuple(add = 1, mul = 2), b= op_tuple(add = 2, mul = 3)) | |||
| // out = nest.map_structure_up_to( | |||
| // inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) | |||
| // self.assertEqual(out.a, 6) | |||
| // self.assertEqual(out.b, 15) | |||
| // # Lists. | |||
| // data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||
| // name_list = ["evens", ["odds", "primes"]] | |||
| // out = nest.map_structure_up_to( | |||
| // name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), | |||
| // name_list, data_list) | |||
| // self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) | |||
| // # Dicts. | |||
| // inp_val = dict(a= 2, b= 3) | |||
| // inp_ops = dict(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) | |||
| // out = nest.map_structure_up_to( | |||
| // inp_val, | |||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| // self.assertEqual(out["a"], 6) | |||
| // self.assertEqual(out["b"], 15) | |||
| // # Non-equal dicts. | |||
| // inp_val = dict(a= 2, b= 3) | |||
| // inp_ops = dict(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) | |||
| // with self.assertRaisesRegexp(ValueError, "same keys"): | |||
| // nest.map_structure_up_to( | |||
| // inp_val, | |||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| // # Dict+custom mapping. | |||
| // inp_val = dict(a= 2, b= 3) | |||
| // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), b= dict(add = 2, mul = 3)) | |||
| // out = nest.map_structure_up_to( | |||
| // inp_val, | |||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| // self.assertEqual(out["a"], 6) | |||
| // self.assertEqual(out["b"], 15) | |||
| // # Non-equal dict/mapping. | |||
| // inp_val = dict(a= 2, b= 3) | |||
| // inp_ops = _CustomMapping(a= dict(add = 1, mul = 2), c= dict(add = 2, mul = 3)) | |||
| // with self.assertRaisesRegexp(ValueError, "same keys"): | |||
| // nest.map_structure_up_to( | |||
| // inp_val, | |||
| // lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| // def testGetTraverseShallowStructure(self): | |||
| // scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] | |||
| // scalar_traverse_r = nest.get_traverse_shallow_structure( | |||
| // lambda s: not isinstance(s, tuple), | |||
| // scalar_traverse_input) | |||
| // self.assertEqual(scalar_traverse_r, | |||
| // [True, True, False, [True, True], {"a": False}, []]) | |||
| // nest.assert_shallow_structure(scalar_traverse_r, | |||
| // scalar_traverse_input) | |||
| // structure_traverse_input = [(1, [2]), ([1], 2)] | |||
| // structure_traverse_r = nest.get_traverse_shallow_structure( | |||
| // lambda s: (True, False) if isinstance(s, tuple) else True, | |||
| // structure_traverse_input) | |||
| // self.assertEqual(structure_traverse_r, | |||
| // [(True, False), ([True], False)]) | |||
| // nest.assert_shallow_structure(structure_traverse_r, | |||
| // structure_traverse_input) | |||
| // with self.assertRaisesRegexp(TypeError, "returned structure"): | |||
| // nest.get_traverse_shallow_structure(lambda _: [True], 0) | |||
| // with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): | |||
| // nest.get_traverse_shallow_structure(lambda _: 1, [1]) | |||
| // with self.assertRaisesRegexp( | |||
| // TypeError, "didn't return a depth=1 structure of bools"): | |||
| // nest.get_traverse_shallow_structure(lambda _: [1], [1]) | |||
| // def testYieldFlatStringPaths(self): | |||
| // for inputs_expected in ({"inputs": [], "expected": []}, | |||
| // {"inputs": 3, "expected": [()]}, | |||
| // {"inputs": [3], "expected": [(0,)]}, | |||
| // {"inputs": {"a": 3}, "expected": [("a",)]}, | |||
| // {"inputs": {"a": {"b": 4}}, | |||
| // "expected": [("a", "b")]}, | |||
| // {"inputs": [{"a": 2}], "expected": [(0, "a")]}, | |||
| // {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, | |||
| // {"inputs": [{"a": [(23, 42)]}], | |||
| // "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, | |||
| // {"inputs": [{"a": ([23], 42)}], | |||
| // "expected": [(0, "a", 0, 0), (0, "a", 1)]}, | |||
| // {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, | |||
| // "expected": [("a", "a"), ("c", 0, 0, 0)]}, | |||
| // {"inputs": {"0": [{"1": 23}]}, | |||
| // "expected": [("0", 0, "1")]}): | |||
| // inputs = inputs_expected["inputs"] | |||
| // expected = inputs_expected["expected"] | |||
| // self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) | |||
| // def testFlattenWithStringPaths(self): | |||
| // for inputs_expected in ( | |||
| // {"inputs": [], "expected": []}, | |||
| // {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, | |||
| // {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): | |||
| // inputs = inputs_expected["inputs"] | |||
| // expected = inputs_expected["expected"] | |||
| // self.assertEqual( | |||
| // nest.flatten_with_joined_string_paths(inputs, separator="/"), | |||
| // expected) | |||
| // # Need a separate test for namedtuple as we can't declare tuple definitions | |||
| // # in the @parameterized arguments. | |||
| // def testFlattenNamedTuple(self): | |||
| // # pylint: disable=invalid-name | |||
| // Foo = collections.namedtuple("Foo", ["a", "b"]) | |||
| // Bar = collections.namedtuple("Bar", ["c", "d"]) | |||
| // # pylint: enable=invalid-name | |||
| // test_cases = [ | |||
| // (Foo(a = 3, b = Bar(c = 23, d = 42)), | |||
| // [("a", 3), ("b/c", 23), ("b/d", 42)]), | |||
| // (Foo(a = Bar(c = 23, d = 42), b = Bar(c = 0, d = "something")), | |||
| // [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), | |||
| // (Bar(c = 42, d = 43), | |||
| // [("c", 42), ("d", 43)]), | |||
| // (Bar(c =[42], d = 43), | |||
| // [("c/0", 42), ("d", 43)]), | |||
| // ] | |||
| // for inputs, expected in test_cases: | |||
| // self.assertEqual( | |||
| // list(nest.flatten_with_joined_string_paths(inputs)), expected) | |||
| // @parameterized.named_parameters( | |||
| // ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), | |||
| // ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, | |||
| // {"a": ("a", 4), "b": ("b", 6)}), | |||
| // ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), | |||
| // ("nested", | |||
| // {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, | |||
| // {"a": [("a/0", 10), ("a/1", 12)], | |||
| // "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) | |||
| // def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): | |||
| // def format_sum(path, * values): | |||
| // return (path, sum(values)) | |||
| // result = nest.map_structure_with_paths(format_sum, s1, s2, | |||
| // check_types=check_types) | |||
| // self.assertEqual(expected, result) | |||
| // @parameterized.named_parameters( | |||
| // ("tuples", (1, 2), (3, 4, 5), ValueError), | |||
| // ("dicts", {"a": 1}, {"b": 2}, ValueError), | |||
| // ("mixed", (1, 2), [3, 4], TypeError), | |||
| // ("nested", | |||
| // {"a": [2, 3], "b": [1, 3]}, | |||
| // {"b": [5, 6, 7], "a": [8, 9]}, | |||
| // ValueError | |||
| // )) | |||
| // def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): | |||
| // with self.assertRaises(error_type): | |||
| // nest.map_structure_with_paths(lambda path, * s: 0, s1, s2) | |||
| //class NestBenchmark(test.Benchmark): | |||
| // def run_and_report(self, s1, s2, name): | |||
| // burn_iter, test_iter = 100, 30000 | |||
| // for _ in xrange(burn_iter) : | |||
| // nest.assert_same_structure(s1, s2) | |||
| // t0 = time.time() | |||
| // for _ in xrange(test_iter) : | |||
| // nest.assert_same_structure(s1, s2) | |||
| // t1 = time.time() | |||
| // self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, | |||
| // name=name) | |||
| // def benchmark_assert_structure(self): | |||
| // s1 = (((1, 2), 3), 4, (5, 6)) | |||
| // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||
| // self.run_and_report(s1, s2, "assert_same_structure_6_elem") | |||
| // s1 = (((1, 2), 3), 4, (5, 6)) * 10 | |||
| // s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 | |||
| // self.run_and_report(s1, s2, "assert_same_structure_60_elem") | |||
| //if __name__ == "__main__": | |||
| // test.main() | |||
| } | |||
| } | |||
| @@ -0,0 +1,883 @@ | |||
| # Copyright 2016 The TensorFlow Authors. All Rights Reserved. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================== | |||
| """Tests for utilities working with arbitrarily nested structures.""" | |||
| from __future__ import absolute_import | |||
| from __future__ import division | |||
| from __future__ import print_function | |||
| import collections | |||
| import time | |||
| from absl.testing import parameterized | |||
| import numpy as np | |||
| from six.moves import xrange # pylint: disable=redefined-builtin | |||
| from tensorflow.python.framework import constant_op | |||
| from tensorflow.python.framework import dtypes | |||
| from tensorflow.python.framework import test_util | |||
| from tensorflow.python.ops import array_ops | |||
| from tensorflow.python.ops import math_ops | |||
| from tensorflow.python.platform import test | |||
| from tensorflow.python.util import nest | |||
| try: | |||
| import attr # pylint:disable=g-import-not-at-top | |||
| except ImportError: | |||
| attr = None | |||
| class _CustomMapping(collections.Mapping): | |||
| def __init__(self, *args, **kwargs): | |||
| self._wrapped = dict(*args, **kwargs) | |||
| def __getitem__(self, key): | |||
| return self._wrapped[key] | |||
| def __iter__(self): | |||
| return iter(self._wrapped) | |||
| def __len__(self): | |||
| return len(self._wrapped) | |||
| class NestTest(parameterized.TestCase, test.TestCase): | |||
| PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name | |||
| if attr: | |||
| class BadAttr(object): | |||
| """Class that has a non-iterable __attrs_attrs__.""" | |||
| __attrs_attrs__ = None | |||
| @attr.s | |||
| class SampleAttr(object): | |||
| field1 = attr.ib() | |||
| field2 = attr.ib() | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testAttrsFlattenAndPack(self): | |||
| if attr is None: | |||
| self.skipTest("attr module is unavailable.") | |||
| field_values = [1, 2] | |||
| sample_attr = NestTest.SampleAttr(*field_values) | |||
| self.assertFalse(nest._is_attrs(field_values)) | |||
| self.assertTrue(nest._is_attrs(sample_attr)) | |||
| flat = nest.flatten(sample_attr) | |||
| self.assertEqual(field_values, flat) | |||
| restructured_from_flat = nest.pack_sequence_as(sample_attr, flat) | |||
| self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr) | |||
| self.assertEqual(restructured_from_flat, sample_attr) | |||
| # Check that flatten fails if attributes are not iterable | |||
| with self.assertRaisesRegexp(TypeError, "object is not iterable"): | |||
| flat = nest.flatten(NestTest.BadAttr()) | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testFlattenAndPack(self): | |||
| structure = ((3, 4), 5, (6, 7, (9, 10), 8)) | |||
| flat = ["a", "b", "c", "d", "e", "f", "g", "h"] | |||
| self.assertEqual(nest.flatten(structure), [3, 4, 5, 6, 7, 9, 10, 8]) | |||
| self.assertEqual( | |||
| nest.pack_sequence_as(structure, flat), (("a", "b"), "c", | |||
| ("d", "e", ("f", "g"), "h"))) | |||
| structure = (NestTest.PointXY(x=4, y=2), | |||
| ((NestTest.PointXY(x=1, y=0),),)) | |||
| flat = [4, 2, 1, 0] | |||
| self.assertEqual(nest.flatten(structure), flat) | |||
| restructured_from_flat = nest.pack_sequence_as(structure, flat) | |||
| self.assertEqual(restructured_from_flat, structure) | |||
| self.assertEqual(restructured_from_flat[0].x, 4) | |||
| self.assertEqual(restructured_from_flat[0].y, 2) | |||
| self.assertEqual(restructured_from_flat[1][0][0].x, 1) | |||
| self.assertEqual(restructured_from_flat[1][0][0].y, 0) | |||
| self.assertEqual([5], nest.flatten(5)) | |||
| self.assertEqual([np.array([5])], nest.flatten(np.array([5]))) | |||
| self.assertEqual("a", nest.pack_sequence_as(5, ["a"])) | |||
| self.assertEqual( | |||
| np.array([5]), nest.pack_sequence_as("scalar", [np.array([5])])) | |||
| with self.assertRaisesRegexp(ValueError, "Structure is a scalar"): | |||
| nest.pack_sequence_as("scalar", [4, 5]) | |||
| with self.assertRaisesRegexp(TypeError, "flat_sequence"): | |||
| nest.pack_sequence_as([4, 5], "bad_sequence") | |||
| with self.assertRaises(ValueError): | |||
| nest.pack_sequence_as([5, 6, [7, 8]], ["a", "b", "c"]) | |||
| @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||
| {"mapping_type": _CustomMapping}) | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testFlattenDictOrder(self, mapping_type): | |||
| """`flatten` orders dicts by key, including OrderedDicts.""" | |||
| ordered = mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]) | |||
| plain = {"d": 3, "b": 1, "a": 0, "c": 2} | |||
| ordered_flat = nest.flatten(ordered) | |||
| plain_flat = nest.flatten(plain) | |||
| self.assertEqual([0, 1, 2, 3], ordered_flat) | |||
| self.assertEqual([0, 1, 2, 3], plain_flat) | |||
| @parameterized.parameters({"mapping_type": collections.OrderedDict}, | |||
| {"mapping_type": _CustomMapping}) | |||
| def testPackDictOrder(self, mapping_type): | |||
| """Packing orders dicts by key, including OrderedDicts.""" | |||
| custom = mapping_type([("d", 0), ("b", 0), ("a", 0), ("c", 0)]) | |||
| plain = {"d": 0, "b": 0, "a": 0, "c": 0} | |||
| seq = [0, 1, 2, 3] | |||
| custom_reconstruction = nest.pack_sequence_as(custom, seq) | |||
| plain_reconstruction = nest.pack_sequence_as(plain, seq) | |||
| self.assertIsInstance(custom_reconstruction, mapping_type) | |||
| self.assertIsInstance(plain_reconstruction, dict) | |||
| self.assertEqual( | |||
| mapping_type([("d", 3), ("b", 1), ("a", 0), ("c", 2)]), | |||
| custom_reconstruction) | |||
| self.assertEqual({"d": 3, "b": 1, "a": 0, "c": 2}, plain_reconstruction) | |||
| Abc = collections.namedtuple("A", ("b", "c")) # pylint: disable=invalid-name | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testFlattenAndPack_withDicts(self): | |||
| # A nice messy mix of tuples, lists, dicts, and `OrderedDict`s. | |||
| mess = [ | |||
| "z", | |||
| NestTest.Abc(3, 4), { | |||
| "d": _CustomMapping({ | |||
| 41: 4 | |||
| }), | |||
| "c": [ | |||
| 1, | |||
| collections.OrderedDict([ | |||
| ("b", 3), | |||
| ("a", 2), | |||
| ]), | |||
| ], | |||
| "b": 5 | |||
| }, 17 | |||
| ] | |||
| flattened = nest.flatten(mess) | |||
| self.assertEqual(flattened, ["z", 3, 4, 5, 1, 2, 3, 4, 17]) | |||
| structure_of_mess = [ | |||
| 14, | |||
| NestTest.Abc("a", True), | |||
| { | |||
| "d": _CustomMapping({ | |||
| 41: 42 | |||
| }), | |||
| "c": [ | |||
| 0, | |||
| collections.OrderedDict([ | |||
| ("b", 9), | |||
| ("a", 8), | |||
| ]), | |||
| ], | |||
| "b": 3 | |||
| }, | |||
| "hi everybody", | |||
| ] | |||
| unflattened = nest.pack_sequence_as(structure_of_mess, flattened) | |||
| self.assertEqual(unflattened, mess) | |||
| # Check also that the OrderedDict was created, with the correct key order. | |||
| unflattened_ordered_dict = unflattened[2]["c"][1] | |||
| self.assertIsInstance(unflattened_ordered_dict, collections.OrderedDict) | |||
| self.assertEqual(list(unflattened_ordered_dict.keys()), ["b", "a"]) | |||
| unflattened_custom_mapping = unflattened[2]["d"] | |||
| self.assertIsInstance(unflattened_custom_mapping, _CustomMapping) | |||
| self.assertEqual(list(unflattened_custom_mapping.keys()), [41]) | |||
| def testFlatten_numpyIsNotFlattened(self): | |||
| structure = np.array([1, 2, 3]) | |||
| flattened = nest.flatten(structure) | |||
| self.assertEqual(len(flattened), 1) | |||
| def testFlatten_stringIsNotFlattened(self): | |||
| structure = "lots of letters" | |||
| flattened = nest.flatten(structure) | |||
| self.assertEqual(len(flattened), 1) | |||
| unflattened = nest.pack_sequence_as("goodbye", flattened) | |||
| self.assertEqual(structure, unflattened) | |||
| def testPackSequenceAs_notIterableError(self): | |||
| with self.assertRaisesRegexp(TypeError, | |||
| "flat_sequence must be a sequence"): | |||
| nest.pack_sequence_as("hi", "bye") | |||
| def testPackSequenceAs_wrongLengthsError(self): | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| "Structure had 2 elements, but flat_sequence had 3 elements."): | |||
| nest.pack_sequence_as(["hello", "world"], | |||
| ["and", "goodbye", "again"]) | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testIsSequence(self): | |||
| self.assertFalse(nest.is_sequence("1234")) | |||
| self.assertTrue(nest.is_sequence([1, 3, [4, 5]])) | |||
| self.assertTrue(nest.is_sequence(((7, 8), (5, 6)))) | |||
| self.assertTrue(nest.is_sequence([])) | |||
| self.assertTrue(nest.is_sequence({"a": 1, "b": 2})) | |||
| self.assertFalse(nest.is_sequence(set([1, 2]))) | |||
| ones = array_ops.ones([2, 3]) | |||
| self.assertFalse(nest.is_sequence(ones)) | |||
| self.assertFalse(nest.is_sequence(math_ops.tanh(ones))) | |||
| self.assertFalse(nest.is_sequence(np.ones((4, 5)))) | |||
| @parameterized.parameters({"mapping_type": _CustomMapping}, | |||
| {"mapping_type": dict}) | |||
| def testFlattenDictItems(self, mapping_type): | |||
| dictionary = mapping_type({(4, 5, (6, 8)): ("a", "b", ("c", "d"))}) | |||
| flat = {4: "a", 5: "b", 6: "c", 8: "d"} | |||
| self.assertEqual(nest.flatten_dict_items(dictionary), flat) | |||
| with self.assertRaises(TypeError): | |||
| nest.flatten_dict_items(4) | |||
| bad_dictionary = mapping_type({(4, 5, (4, 8)): ("a", "b", ("c", "d"))}) | |||
| with self.assertRaisesRegexp(ValueError, "not unique"): | |||
| nest.flatten_dict_items(bad_dictionary) | |||
| another_bad_dictionary = mapping_type({ | |||
| (4, 5, (6, 8)): ("a", "b", ("c", ("d", "e"))) | |||
| }) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, "Key had [0-9]* elements, but value had [0-9]* elements"): | |||
| nest.flatten_dict_items(another_bad_dictionary) | |||
| # pylint does not correctly recognize these as class names and | |||
| # suggests to use variable style under_score naming. | |||
| # pylint: disable=invalid-name | |||
| Named0ab = collections.namedtuple("named_0", ("a", "b")) | |||
| Named1ab = collections.namedtuple("named_1", ("a", "b")) | |||
| SameNameab = collections.namedtuple("same_name", ("a", "b")) | |||
| SameNameab2 = collections.namedtuple("same_name", ("a", "b")) | |||
| SameNamexy = collections.namedtuple("same_name", ("x", "y")) | |||
| SameName1xy = collections.namedtuple("same_name_1", ("x", "y")) | |||
| SameName1xy2 = collections.namedtuple("same_name_1", ("x", "y")) | |||
| NotSameName = collections.namedtuple("not_same_name", ("a", "b")) | |||
| # pylint: enable=invalid-name | |||
| class SameNamedType1(SameNameab): | |||
| pass | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testAssertSameStructure(self): | |||
| structure1 = (((1, 2), 3), 4, (5, 6)) | |||
| structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||
| structure_different_num_elements = ("spam", "eggs") | |||
| structure_different_nesting = (((1, 2), 3), 4, 5, (6,)) | |||
| nest.assert_same_structure(structure1, structure2) | |||
| nest.assert_same_structure("abc", 1.0) | |||
| nest.assert_same_structure("abc", np.array([0, 1])) | |||
| nest.assert_same_structure("abc", constant_op.constant([0, 1])) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| ("The two structures don't have the same nested structure\\.\n\n" | |||
| "First structure:.*?\n\n" | |||
| "Second structure:.*\n\n" | |||
| "More specifically: Substructure " | |||
| r'"type=tuple str=\(\(1, 2\), 3\)" is a sequence, while ' | |||
| 'substructure "type=str str=spam" is not\n' | |||
| "Entire first structure:\n" | |||
| r"\(\(\(\., \.\), \.\), \., \(\., \.\)\)\n" | |||
| "Entire second structure:\n" | |||
| r"\(\., \.\)")): | |||
| nest.assert_same_structure(structure1, structure_different_num_elements) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| ("The two structures don't have the same nested structure\\.\n\n" | |||
| "First structure:.*?\n\n" | |||
| "Second structure:.*\n\n" | |||
| r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||
| r'is a sequence, while substructure "type=ndarray str=\[0 1\]" ' | |||
| "is not")): | |||
| nest.assert_same_structure([0, 1], np.array([0, 1])) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| ("The two structures don't have the same nested structure\\.\n\n" | |||
| "First structure:.*?\n\n" | |||
| "Second structure:.*\n\n" | |||
| r'More specifically: Substructure "type=list str=\[0, 1\]" ' | |||
| 'is a sequence, while substructure "type=int str=0" ' | |||
| "is not")): | |||
| nest.assert_same_structure(0, [0, 1]) | |||
| self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), [0, 1]) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| ("don't have the same nested structure\\.\n\n" | |||
| "First structure: .*?\n\nSecond structure: ")): | |||
| nest.assert_same_structure(structure1, structure_different_nesting) | |||
| self.assertRaises(TypeError, nest.assert_same_structure, (0, 1), | |||
| NestTest.Named0ab("a", "b")) | |||
| nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||
| NestTest.Named0ab("a", "b")) | |||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||
| NestTest.Named0ab(3, 4), NestTest.Named1ab(3, 4)) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| ("don't have the same nested structure\\.\n\n" | |||
| "First structure: .*?\n\nSecond structure: ")): | |||
| nest.assert_same_structure(NestTest.Named0ab(3, 4), | |||
| NestTest.Named0ab([3], 4)) | |||
| with self.assertRaisesRegexp( | |||
| ValueError, | |||
| ("don't have the same nested structure\\.\n\n" | |||
| "First structure: .*?\n\nSecond structure: ")): | |||
| nest.assert_same_structure([[3], 4], [3, [4]]) | |||
| structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||
| with self.assertRaisesRegexp(TypeError, | |||
| "don't have the same sequence type"): | |||
| nest.assert_same_structure(structure1, structure1_list) | |||
| nest.assert_same_structure(structure1, structure2, check_types=False) | |||
| nest.assert_same_structure(structure1, structure1_list, check_types=False) | |||
| with self.assertRaisesRegexp(ValueError, | |||
| "don't have the same set of keys"): | |||
| nest.assert_same_structure({"a": 1}, {"b": 1}) | |||
| nest.assert_same_structure(NestTest.SameNameab(0, 1), | |||
| NestTest.SameNameab2(2, 3)) | |||
| # This assertion is expected to pass: two namedtuples with the same | |||
| # name and field names are considered to be identical. | |||
| nest.assert_same_structure( | |||
| NestTest.SameNameab(NestTest.SameName1xy(0, 1), 2), | |||
| NestTest.SameNameab2(NestTest.SameName1xy2(2, 3), 4)) | |||
| expected_message = "The two structures don't have the same.*" | |||
| with self.assertRaisesRegexp(ValueError, expected_message): | |||
| nest.assert_same_structure( | |||
| NestTest.SameNameab(0, NestTest.SameNameab2(1, 2)), | |||
| NestTest.SameNameab2(NestTest.SameNameab(0, 1), 2)) | |||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||
| NestTest.SameNameab(0, 1), NestTest.NotSameName(2, 3)) | |||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||
| NestTest.SameNameab(0, 1), NestTest.SameNamexy(2, 3)) | |||
| self.assertRaises(TypeError, nest.assert_same_structure, | |||
| NestTest.SameNameab(0, 1), NestTest.SameNamedType1(2, 3)) | |||
| EmptyNT = collections.namedtuple("empty_nt", "") # pylint: disable=invalid-name | |||
| def testHeterogeneousComparison(self): | |||
| nest.assert_same_structure({"a": 4}, _CustomMapping(a=3)) | |||
| nest.assert_same_structure(_CustomMapping(b=3), {"b": 4}) | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testMapStructure(self): | |||
| structure1 = (((1, 2), 3), 4, (5, 6)) | |||
| structure2 = (((7, 8), 9), 10, (11, 12)) | |||
| structure1_plus1 = nest.map_structure(lambda x: x + 1, structure1) | |||
| nest.assert_same_structure(structure1, structure1_plus1) | |||
| self.assertAllEqual( | |||
| [2, 3, 4, 5, 6, 7], | |||
| nest.flatten(structure1_plus1)) | |||
| structure1_plus_structure2 = nest.map_structure( | |||
| lambda x, y: x + y, structure1, structure2) | |||
| self.assertEqual( | |||
| (((1 + 7, 2 + 8), 3 + 9), 4 + 10, (5 + 11, 6 + 12)), | |||
| structure1_plus_structure2) | |||
| self.assertEqual(3, nest.map_structure(lambda x: x - 1, 4)) | |||
| self.assertEqual(7, nest.map_structure(lambda x, y: x + y, 3, 4)) | |||
| # Empty structures | |||
| self.assertEqual((), nest.map_structure(lambda x: x + 1, ())) | |||
| self.assertEqual([], nest.map_structure(lambda x: x + 1, [])) | |||
| self.assertEqual({}, nest.map_structure(lambda x: x + 1, {})) | |||
| self.assertEqual(NestTest.EmptyNT(), nest.map_structure(lambda x: x + 1, | |||
| NestTest.EmptyNT())) | |||
| # This is checking actual equality of types, empty list != empty tuple | |||
| self.assertNotEqual((), nest.map_structure(lambda x: x + 1, [])) | |||
| with self.assertRaisesRegexp(TypeError, "callable"): | |||
| nest.map_structure("bad", structure1_plus1) | |||
| with self.assertRaisesRegexp(ValueError, "at least one structure"): | |||
| nest.map_structure(lambda x: x) | |||
| with self.assertRaisesRegexp(ValueError, "same number of elements"): | |||
| nest.map_structure(lambda x, y: None, (3, 4), (3, 4, 5)) | |||
| with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
| nest.map_structure(lambda x, y: None, 3, (3,)) | |||
| with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||
| nest.map_structure(lambda x, y: None, ((3, 4), 5), [(3, 4), 5]) | |||
| with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
| nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5))) | |||
| structure1_list = [[[1, 2], 3], 4, [5, 6]] | |||
| with self.assertRaisesRegexp(TypeError, "same sequence type"): | |||
| nest.map_structure(lambda x, y: None, structure1, structure1_list) | |||
| nest.map_structure(lambda x, y: None, structure1, structure1_list, | |||
| check_types=False) | |||
| with self.assertRaisesRegexp(ValueError, "same nested structure"): | |||
| nest.map_structure(lambda x, y: None, ((3, 4), 5), (3, (4, 5)), | |||
| check_types=False) | |||
| with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||
| nest.map_structure(lambda x: None, structure1, foo="a") | |||
| with self.assertRaisesRegexp(ValueError, "Only valid keyword argument"): | |||
| nest.map_structure(lambda x: None, structure1, check_types=False, foo="a") | |||
| ABTuple = collections.namedtuple("ab_tuple", "a, b") # pylint: disable=invalid-name | |||
| @test_util.assert_no_new_pyobjects_executing_eagerly | |||
| def testMapStructureWithStrings(self): | |||
| inp_a = NestTest.ABTuple(a="foo", b=("bar", "baz")) | |||
| inp_b = NestTest.ABTuple(a=2, b=(1, 3)) | |||
| out = nest.map_structure(lambda string, repeats: string * repeats, | |||
| inp_a, | |||
| inp_b) | |||
| self.assertEqual("foofoo", out.a) | |||
| self.assertEqual("bar", out.b[0]) | |||
| self.assertEqual("bazbazbaz", out.b[1]) | |||
| nt = NestTest.ABTuple(a=("something", "something_else"), | |||
| b="yet another thing") | |||
| rev_nt = nest.map_structure(lambda x: x[::-1], nt) | |||
| # Check the output is the correct structure, and all strings are reversed. | |||
| nest.assert_same_structure(nt, rev_nt) | |||
| self.assertEqual(nt.a[0][::-1], rev_nt.a[0]) | |||
| self.assertEqual(nt.a[1][::-1], rev_nt.a[1]) | |||
| self.assertEqual(nt.b[::-1], rev_nt.b) | |||
| @test_util.run_deprecated_v1 | |||
| def testMapStructureOverPlaceholders(self): | |||
| inp_a = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||
| array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||
| inp_b = (array_ops.placeholder(dtypes.float32, shape=[3, 4]), | |||
| array_ops.placeholder(dtypes.float32, shape=[3, 7])) | |||
| output = nest.map_structure(lambda x1, x2: x1 + x2, inp_a, inp_b) | |||
| nest.assert_same_structure(output, inp_a) | |||
| self.assertShapeEqual(np.zeros((3, 4)), output[0]) | |||
| self.assertShapeEqual(np.zeros((3, 7)), output[1]) | |||
| feed_dict = { | |||
| inp_a: (np.random.randn(3, 4), np.random.randn(3, 7)), | |||
| inp_b: (np.random.randn(3, 4), np.random.randn(3, 7)) | |||
| } | |||
| with self.cached_session() as sess: | |||
| output_np = sess.run(output, feed_dict=feed_dict) | |||
| self.assertAllClose(output_np[0], | |||
| feed_dict[inp_a][0] + feed_dict[inp_b][0]) | |||
| self.assertAllClose(output_np[1], | |||
| feed_dict[inp_a][1] + feed_dict[inp_b][1]) | |||
| def testAssertShallowStructure(self): | |||
| inp_ab = ["a", "b"] | |||
| inp_abc = ["a", "b", "c"] | |||
| expected_message = ( | |||
| "The two structures don't have the same sequence length. Input " | |||
| "structure has length 2, while shallow structure has length 3.") | |||
| with self.assertRaisesRegexp(ValueError, expected_message): | |||
| nest.assert_shallow_structure(inp_abc, inp_ab) | |||
| inp_ab1 = [(1, 1), (2, 2)] | |||
| inp_ab2 = [[1, 1], [2, 2]] | |||
| expected_message = ( | |||
| "The two structures don't have the same sequence type. Input structure " | |||
| "has type <(type|class) 'tuple'>, while shallow structure has type " | |||
| "<(type|class) 'list'>.") | |||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||
| nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||
| nest.assert_shallow_structure(inp_ab2, inp_ab1, check_types=False) | |||
| inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}} | |||
| inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}} | |||
| expected_message = ( | |||
| r"The two structures don't have the same keys. Input " | |||
| r"structure has keys \['c'\], while shallow structure has " | |||
| r"keys \['d'\].") | |||
| with self.assertRaisesRegexp(ValueError, expected_message): | |||
| nest.assert_shallow_structure(inp_ab2, inp_ab1) | |||
| inp_ab = collections.OrderedDict([("a", 1), ("b", (2, 3))]) | |||
| inp_ba = collections.OrderedDict([("b", (2, 3)), ("a", 1)]) | |||
| nest.assert_shallow_structure(inp_ab, inp_ba) | |||
| # This assertion is expected to pass: two namedtuples with the same | |||
| # name and field names are considered to be identical. | |||
| inp_shallow = NestTest.SameNameab(1, 2) | |||
| inp_deep = NestTest.SameNameab2(1, [1, 2, 3]) | |||
| nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=False) | |||
| nest.assert_shallow_structure(inp_shallow, inp_deep, check_types=True) | |||
| def testFlattenUpTo(self): | |||
| # Shallow tree ends at scalar. | |||
| input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]] | |||
| shallow_tree = [[True, True], [False, True]] | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]]) | |||
| self.assertEqual(flattened_shallow_tree, [True, True, False, True]) | |||
| # Shallow tree ends at string. | |||
| input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]] | |||
| shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]] | |||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| input_tree) | |||
| input_tree_flattened = nest.flatten(input_tree) | |||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| [("a", 1), ("b", 2), ("c", 3), ("d", 4)]) | |||
| self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4]) | |||
| # Make sure dicts are correctly flattened, yielding values, not keys. | |||
| input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]} | |||
| shallow_tree = {"a": 0, "b": 0, "d": [0, 0]} | |||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| input_tree) | |||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| [1, {"c": 2}, 3, (4, 5)]) | |||
| # Namedtuples. | |||
| ab_tuple = NestTest.ABTuple | |||
| input_tree = ab_tuple(a=[0, 1], b=2) | |||
| shallow_tree = ab_tuple(a=0, b=1) | |||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| input_tree) | |||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| [[0, 1], 2]) | |||
| # Nested dicts, OrderedDicts and namedtuples. | |||
| input_tree = collections.OrderedDict( | |||
| [("a", ab_tuple(a=[0, {"b": 1}], b=2)), | |||
| ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})]) | |||
| shallow_tree = input_tree | |||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| input_tree) | |||
| self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4]) | |||
| shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})]) | |||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| input_tree) | |||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| [ab_tuple(a=[0, {"b": 1}], b=2), | |||
| 3, | |||
| collections.OrderedDict([("f", 4)])]) | |||
| shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)]) | |||
| input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree, | |||
| input_tree) | |||
| self.assertEqual(input_tree_flattened_as_shallow_tree, | |||
| [ab_tuple(a=[0, {"b": 1}], b=2), | |||
| {"d": 3, "e": collections.OrderedDict([("f", 4)])}]) | |||
| ## Shallow non-list edge-case. | |||
| # Using iterable elements. | |||
| input_tree = ["input_tree"] | |||
| shallow_tree = "shallow_tree" | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| input_tree = ["input_tree_0", "input_tree_1"] | |||
| shallow_tree = "shallow_tree" | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| # Using non-iterable elements. | |||
| input_tree = [0] | |||
| shallow_tree = 9 | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| input_tree = [0, 1] | |||
| shallow_tree = 9 | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| ## Both non-list edge-case. | |||
| # Using iterable elements. | |||
| input_tree = "input_tree" | |||
| shallow_tree = "shallow_tree" | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| # Using non-iterable elements. | |||
| input_tree = 0 | |||
| shallow_tree = 0 | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_input_tree, [input_tree]) | |||
| self.assertEqual(flattened_shallow_tree, [shallow_tree]) | |||
| ## Input non-list edge-case. | |||
| # Using iterable elements. | |||
| input_tree = "input_tree" | |||
| shallow_tree = ["shallow_tree"] | |||
| expected_message = ("If shallow structure is a sequence, input must also " | |||
| "be a sequence. Input has type: <(type|class) 'str'>.") | |||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| input_tree = "input_tree" | |||
| shallow_tree = ["shallow_tree_9", "shallow_tree_8"] | |||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| # Using non-iterable elements. | |||
| input_tree = 0 | |||
| shallow_tree = [9] | |||
| expected_message = ("If shallow structure is a sequence, input must also " | |||
| "be a sequence. Input has type: <(type|class) 'int'>.") | |||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| input_tree = 0 | |||
| shallow_tree = [9, 8] | |||
| with self.assertRaisesRegexp(TypeError, expected_message): | |||
| flattened_input_tree = nest.flatten_up_to(shallow_tree, input_tree) | |||
| flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree) | |||
| self.assertEqual(flattened_shallow_tree, shallow_tree) | |||
| def testMapStructureUpTo(self): | |||
| # Named tuples. | |||
| ab_tuple = collections.namedtuple("ab_tuple", "a, b") | |||
| op_tuple = collections.namedtuple("op_tuple", "add, mul") | |||
| inp_val = ab_tuple(a=2, b=3) | |||
| inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3)) | |||
| out = nest.map_structure_up_to( | |||
| inp_val, lambda val, ops: (val + ops.add) * ops.mul, inp_val, inp_ops) | |||
| self.assertEqual(out.a, 6) | |||
| self.assertEqual(out.b, 15) | |||
| # Lists. | |||
| data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]] | |||
| name_list = ["evens", ["odds", "primes"]] | |||
| out = nest.map_structure_up_to( | |||
| name_list, lambda name, sec: "first_{}_{}".format(len(sec), name), | |||
| name_list, data_list) | |||
| self.assertEqual(out, ["first_4_evens", ["first_5_odds", "first_3_primes"]]) | |||
| # Dicts. | |||
| inp_val = dict(a=2, b=3) | |||
| inp_ops = dict(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) | |||
| out = nest.map_structure_up_to( | |||
| inp_val, | |||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| self.assertEqual(out["a"], 6) | |||
| self.assertEqual(out["b"], 15) | |||
| # Non-equal dicts. | |||
| inp_val = dict(a=2, b=3) | |||
| inp_ops = dict(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) | |||
| with self.assertRaisesRegexp(ValueError, "same keys"): | |||
| nest.map_structure_up_to( | |||
| inp_val, | |||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| # Dict+custom mapping. | |||
| inp_val = dict(a=2, b=3) | |||
| inp_ops = _CustomMapping(a=dict(add=1, mul=2), b=dict(add=2, mul=3)) | |||
| out = nest.map_structure_up_to( | |||
| inp_val, | |||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| self.assertEqual(out["a"], 6) | |||
| self.assertEqual(out["b"], 15) | |||
| # Non-equal dict/mapping. | |||
| inp_val = dict(a=2, b=3) | |||
| inp_ops = _CustomMapping(a=dict(add=1, mul=2), c=dict(add=2, mul=3)) | |||
| with self.assertRaisesRegexp(ValueError, "same keys"): | |||
| nest.map_structure_up_to( | |||
| inp_val, | |||
| lambda val, ops: (val + ops["add"]) * ops["mul"], inp_val, inp_ops) | |||
| def testGetTraverseShallowStructure(self): | |||
| scalar_traverse_input = [3, 4, (1, 2, [0]), [5, 6], {"a": (7,)}, []] | |||
| scalar_traverse_r = nest.get_traverse_shallow_structure( | |||
| lambda s: not isinstance(s, tuple), | |||
| scalar_traverse_input) | |||
| self.assertEqual(scalar_traverse_r, | |||
| [True, True, False, [True, True], {"a": False}, []]) | |||
| nest.assert_shallow_structure(scalar_traverse_r, | |||
| scalar_traverse_input) | |||
| structure_traverse_input = [(1, [2]), ([1], 2)] | |||
| structure_traverse_r = nest.get_traverse_shallow_structure( | |||
| lambda s: (True, False) if isinstance(s, tuple) else True, | |||
| structure_traverse_input) | |||
| self.assertEqual(structure_traverse_r, | |||
| [(True, False), ([True], False)]) | |||
| nest.assert_shallow_structure(structure_traverse_r, | |||
| structure_traverse_input) | |||
| with self.assertRaisesRegexp(TypeError, "returned structure"): | |||
| nest.get_traverse_shallow_structure(lambda _: [True], 0) | |||
| with self.assertRaisesRegexp(TypeError, "returned a non-bool scalar"): | |||
| nest.get_traverse_shallow_structure(lambda _: 1, [1]) | |||
| with self.assertRaisesRegexp( | |||
| TypeError, "didn't return a depth=1 structure of bools"): | |||
| nest.get_traverse_shallow_structure(lambda _: [1], [1]) | |||
| def testYieldFlatStringPaths(self): | |||
| for inputs_expected in ({"inputs": [], "expected": []}, | |||
| {"inputs": 3, "expected": [()]}, | |||
| {"inputs": [3], "expected": [(0,)]}, | |||
| {"inputs": {"a": 3}, "expected": [("a",)]}, | |||
| {"inputs": {"a": {"b": 4}}, | |||
| "expected": [("a", "b")]}, | |||
| {"inputs": [{"a": 2}], "expected": [(0, "a")]}, | |||
| {"inputs": [{"a": [2]}], "expected": [(0, "a", 0)]}, | |||
| {"inputs": [{"a": [(23, 42)]}], | |||
| "expected": [(0, "a", 0, 0), (0, "a", 0, 1)]}, | |||
| {"inputs": [{"a": ([23], 42)}], | |||
| "expected": [(0, "a", 0, 0), (0, "a", 1)]}, | |||
| {"inputs": {"a": {"a": 2}, "c": [[[4]]]}, | |||
| "expected": [("a", "a"), ("c", 0, 0, 0)]}, | |||
| {"inputs": {"0": [{"1": 23}]}, | |||
| "expected": [("0", 0, "1")]}): | |||
| inputs = inputs_expected["inputs"] | |||
| expected = inputs_expected["expected"] | |||
| self.assertEqual(list(nest.yield_flat_paths(inputs)), expected) | |||
| def testFlattenWithStringPaths(self): | |||
| for inputs_expected in ( | |||
| {"inputs": [], "expected": []}, | |||
| {"inputs": [23, "42"], "expected": [("0", 23), ("1", "42")]}, | |||
| {"inputs": [[[[108]]]], "expected": [("0/0/0/0", 108)]}): | |||
| inputs = inputs_expected["inputs"] | |||
| expected = inputs_expected["expected"] | |||
| self.assertEqual( | |||
| nest.flatten_with_joined_string_paths(inputs, separator="/"), | |||
| expected) | |||
| # Need a separate test for namedtuple as we can't declare tuple definitions | |||
| # in the @parameterized arguments. | |||
| def testFlattenNamedTuple(self): | |||
| # pylint: disable=invalid-name | |||
| Foo = collections.namedtuple("Foo", ["a", "b"]) | |||
| Bar = collections.namedtuple("Bar", ["c", "d"]) | |||
| # pylint: enable=invalid-name | |||
| test_cases = [ | |||
| (Foo(a=3, b=Bar(c=23, d=42)), | |||
| [("a", 3), ("b/c", 23), ("b/d", 42)]), | |||
| (Foo(a=Bar(c=23, d=42), b=Bar(c=0, d="something")), | |||
| [("a/c", 23), ("a/d", 42), ("b/c", 0), ("b/d", "something")]), | |||
| (Bar(c=42, d=43), | |||
| [("c", 42), ("d", 43)]), | |||
| (Bar(c=[42], d=43), | |||
| [("c/0", 42), ("d", 43)]), | |||
| ] | |||
| for inputs, expected in test_cases: | |||
| self.assertEqual( | |||
| list(nest.flatten_with_joined_string_paths(inputs)), expected) | |||
| @parameterized.named_parameters( | |||
| ("tuples", (1, 2), (3, 4), True, (("0", 4), ("1", 6))), | |||
| ("dicts", {"a": 1, "b": 2}, {"b": 4, "a": 3}, True, | |||
| {"a": ("a", 4), "b": ("b", 6)}), | |||
| ("mixed", (1, 2), [3, 4], False, (("0", 4), ("1", 6))), | |||
| ("nested", | |||
| {"a": [2, 3], "b": [1, 2, 3]}, {"b": [5, 6, 7], "a": [8, 9]}, True, | |||
| {"a": [("a/0", 10), ("a/1", 12)], | |||
| "b": [("b/0", 6), ("b/1", 8), ("b/2", 10)]})) | |||
| def testMapWithPathsCompatibleStructures(self, s1, s2, check_types, expected): | |||
| def format_sum(path, *values): | |||
| return (path, sum(values)) | |||
| result = nest.map_structure_with_paths(format_sum, s1, s2, | |||
| check_types=check_types) | |||
| self.assertEqual(expected, result) | |||
| @parameterized.named_parameters( | |||
| ("tuples", (1, 2), (3, 4, 5), ValueError), | |||
| ("dicts", {"a": 1}, {"b": 2}, ValueError), | |||
| ("mixed", (1, 2), [3, 4], TypeError), | |||
| ("nested", | |||
| {"a": [2, 3], "b": [1, 3]}, | |||
| {"b": [5, 6, 7], "a": [8, 9]}, | |||
| ValueError | |||
| )) | |||
| def testMapWithPathsIncompatibleStructures(self, s1, s2, error_type): | |||
| with self.assertRaises(error_type): | |||
| nest.map_structure_with_paths(lambda path, *s: 0, s1, s2) | |||
| class NestBenchmark(test.Benchmark): | |||
| def run_and_report(self, s1, s2, name): | |||
| burn_iter, test_iter = 100, 30000 | |||
| for _ in xrange(burn_iter): | |||
| nest.assert_same_structure(s1, s2) | |||
| t0 = time.time() | |||
| for _ in xrange(test_iter): | |||
| nest.assert_same_structure(s1, s2) | |||
| t1 = time.time() | |||
| self.report_benchmark(iters=test_iter, wall_time=(t1 - t0) / test_iter, | |||
| name=name) | |||
| def benchmark_assert_structure(self): | |||
| s1 = (((1, 2), 3), 4, (5, 6)) | |||
| s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) | |||
| self.run_and_report(s1, s2, "assert_same_structure_6_elem") | |||
| s1 = (((1, 2), 3), 4, (5, 6)) * 10 | |||
| s2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6")) * 10 | |||
| self.run_and_report(s1, s2, "assert_same_structure_60_elem") | |||
| if __name__ == "__main__": | |||
| test.main() | |||
| @@ -84,8 +84,8 @@ namespace TensorFlowNET.UnitTest.ops_test | |||
| var op = g.get_operation_by_name("cond/myop"); | |||
| tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta.txt", as_text:true); | |||
| tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||
| //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta.txt", as_text:true); | |||
| //tf.train.export_meta_graph(@"D:\dev\tensorboard\logdir\sharp.meta", as_text: false); | |||
| self.assertIsNotNone(op); | |||
| self.assertEqual(op.name, "cond/myop"); | |||