diff --git a/test/test_utils.py b/test/test_utils.py index ca36909a8..179d21cf5 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -123,6 +123,7 @@ from youtube_dl.compat import ( compat_chr, compat_etree_fromstring, compat_getenv, + compat_http_cookies, compat_os_name, compat_setenv, compat_str, @@ -132,10 +133,6 @@ from youtube_dl.compat import ( class TestUtil(unittest.TestCase): - # yt-dlp shim - def assertCountEqual(self, expected, got, msg='count should be the same'): - return self.assertEqual(len(tuple(expected)), len(tuple(got)), msg=msg) - def test_timeconvert(self): self.assertTrue(timeconvert('') is None) self.assertTrue(timeconvert('bougrg') is None) @@ -740,28 +737,6 @@ class TestUtil(unittest.TestCase): self.assertRaises( ValueError, multipart_encode, {b'field': b'value'}, boundary='value') - def test_dict_get(self): - FALSE_VALUES = { - 'none': None, - 'false': False, - 'zero': 0, - 'empty_string': '', - 'empty_list': [], - } - d = FALSE_VALUES.copy() - d['a'] = 42 - self.assertEqual(dict_get(d, 'a'), 42) - self.assertEqual(dict_get(d, 'b'), None) - self.assertEqual(dict_get(d, 'b', 42), 42) - self.assertEqual(dict_get(d, ('a', )), 42) - self.assertEqual(dict_get(d, ('b', 'a', )), 42) - self.assertEqual(dict_get(d, ('b', 'c', 'a', 'd', )), 42) - self.assertEqual(dict_get(d, ('b', 'c', )), None) - self.assertEqual(dict_get(d, ('b', 'c', ), 42), 42) - for key, false_value in FALSE_VALUES.items(): - self.assertEqual(dict_get(d, ('b', 'c', key, )), None) - self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value) - def test_merge_dicts(self): self.assertEqual(merge_dicts({'a': 1}, {'b': 2}), {'a': 1, 'b': 2}) self.assertEqual(merge_dicts({'a': 1}, {'a': 2}), {'a': 1}) @@ -1703,24 +1678,46 @@ Line 1 self.assertEqual(variadic('spam', allowed_types=dict), 'spam') self.assertEqual(variadic('spam', allowed_types=[dict]), 'spam') + def test_join_nonempty(self): + self.assertEqual(join_nonempty('a', 'b'), 'a-b') + self.assertEqual(join_nonempty( + 'a', 'b', 'c', 'd', + from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d') + + +class TestTraversal(unittest.TestCase): + str = compat_str + _TEST_DATA = { + 100: 100, + 1.2: 1.2, + 'str': 'str', + 'None': None, + '...': Ellipsis, + 'urls': [ + {'index': 0, 'url': 'https://www.example.com/0'}, + {'index': 1, 'url': 'https://www.example.com/1'}, + ], + 'data': ( + {'index': 2}, + {'index': 3}, + ), + 'dict': {}, + } + + # yt-dlp shim + def assertCountEqual(self, expected, got, msg='count should be the same'): + return self.assertEqual(len(tuple(expected)), len(tuple(got)), msg=msg) + + def assertMaybeCountEqual(self, *args, **kwargs): + if sys.version_info < (3, 7): + # random dict order + return self.assertCountEqual(*args, **kwargs) + else: + return self.assertEqual(*args, **kwargs) + def test_traverse_obj(self): - str = compat_str - _TEST_DATA = { - 100: 100, - 1.2: 1.2, - 'str': 'str', - 'None': None, - '...': Ellipsis, - 'urls': [ - {'index': 0, 'url': 'https://www.example.com/0'}, - {'index': 1, 'url': 'https://www.example.com/1'}, - ], - 'data': ( - {'index': 2}, - {'index': 3}, - ), - 'dict': {}, - } + str = self.str + _TEST_DATA = self._TEST_DATA # define a pukka Iterable def iter_range(stop): @@ -1771,15 +1768,19 @@ Line 1 # Test set as key (transformation/type, like `expected_type`) self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper), )), ['STR'], msg='Function in set should be a transformation') + self.assertEqual(traverse_obj(_TEST_DATA, ('fail', T(lambda _: 'const'))), 'const', + msg='Function in set should always be called') self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str))), ['str'], msg='Type in set should be a type filter') + self.assertMaybeCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str, int))), [100, 'str'], + msg='Multiple types in set should be a type filter') self.assertEqual(traverse_obj(_TEST_DATA, T(dict)), _TEST_DATA, msg='A single set should be wrapped into a path') self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str.upper))), ['STR'], msg='Transformation function should not raise') - self.assertEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str_or_none))), - [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None], - msg='Function in set should be a transformation') + self.assertMaybeCountEqual(traverse_obj(_TEST_DATA, (Ellipsis, T(str_or_none))), + [item for item in map(str_or_none, _TEST_DATA.values()) if item is not None], + msg='Function in set should be a transformation') if __debug__: with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'): traverse_obj(_TEST_DATA, set()) @@ -1992,23 +1993,6 @@ Line 1 self.assertEqual(traverse_obj({}, (0, slice(1)), _traverse_string=True), [], msg='branching should result in list if `traverse_string`') - # Test is_user_input behavior - _IS_USER_INPUT_DATA = {'range8': list(range(8))} - self.assertEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3'), - _is_user_input=True), 3, - msg='allow for string indexing if `is_user_input`') - self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', '3:'), - _is_user_input=True), tuple(range(8))[3:], - msg='allow for string slice if `is_user_input`') - self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':4:2'), - _is_user_input=True), tuple(range(8))[:4:2], - msg='allow step in string slice if `is_user_input`') - self.assertCountEqual(traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':'), - _is_user_input=True), range(8), - msg='`:` should be treated as `...` if `is_user_input`') - with self.assertRaises(TypeError, msg='too many params should result in error'): - traverse_obj(_IS_USER_INPUT_DATA, ('range8', ':::'), _is_user_input=True) - # Test re.Match as input obj mobj = re.match(r'^0(12)(?P3)(4)?$', '0123') self.assertEqual(traverse_obj(mobj, Ellipsis), [x for x in mobj.groups() if x is not None], @@ -2030,14 +2014,151 @@ Line 1 self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'], msg='function on a `re.Match` should give group name as well') + # Test xml.etree.ElementTree.Element as input obj + etree = compat_etree_fromstring(''' + + + 1 + 2008 + 141100 + + + + + 4 + 2011 + 59900 + + + + 68 + 2011 + 13600 + + + + ''') + self.assertEqual(traverse_obj(etree, ''), etree, + msg='empty str key should return the element itself') + self.assertEqual(traverse_obj(etree, 'country'), list(etree), + msg='str key should return all children with that tag name') + self.assertEqual(traverse_obj(etree, Ellipsis), list(etree), + msg='`...` as key should return all children') + self.assertEqual(traverse_obj(etree, lambda _, x: x[0].text == '4'), [etree[1]], + msg='function as key should get element as value') + self.assertEqual(traverse_obj(etree, lambda i, _: i == 1), [etree[1]], + msg='function as key should get index as key') + self.assertEqual(traverse_obj(etree, 0), etree[0], + msg='int key should return the nth child') + self.assertEqual(traverse_obj(etree, './/neighbor/@name'), + ['Austria', 'Switzerland', 'Malaysia', 'Costa Rica', 'Colombia'], + msg='`@` at end of path should give that attribute') + self.assertEqual(traverse_obj(etree, '//neighbor/@fail'), [None, None, None, None, None], + msg='`@` at end of path should give `None`') + self.assertEqual(traverse_obj(etree, ('//neighbor/@', 2)), {'name': 'Malaysia', 'direction': 'N'}, + msg='`@` should give the full attribute dict') + self.assertEqual(traverse_obj(etree, '//year/text()'), ['2008', '2011', '2011'], + msg='`text()` at end of path should give the inner text') + self.assertEqual(traverse_obj(etree, '//*[@direction]/@direction'), ['E', 'W', 'N', 'W', 'E'], + msg='full python xpath features should be supported') + self.assertEqual(traverse_obj(etree, (0, '@name')), 'Liechtenstein', + msg='special transformations should act on current element') + self.assertEqual(traverse_obj(etree, ('country', 0, Ellipsis, 'text()', T(int_or_none))), [1, 2008, 141100], + msg='special transformations should act on current element') + + def test_traversal_unbranching(self): + # str = self.str + _TEST_DATA = self._TEST_DATA + + self.assertEqual(traverse_obj(_TEST_DATA, [(100, 1.2), all]), [100, 1.2], + msg='`all` should give all results as list') + self.assertEqual(traverse_obj(_TEST_DATA, [(100, 1.2), any]), 100, + msg='`any` should give the first result') + self.assertEqual(traverse_obj(_TEST_DATA, [100, all]), [100], + msg='`all` should give list if non branching') + self.assertEqual(traverse_obj(_TEST_DATA, [100, any]), 100, + msg='`any` should give single item if non branching') + self.assertEqual(traverse_obj(_TEST_DATA, [('dict', 'None', 100), all]), [100], + msg='`all` should filter `None` and empty dict') + self.assertEqual(traverse_obj(_TEST_DATA, [('dict', 'None', 100), any]), 100, + msg='`any` should filter `None` and empty dict') + self.assertEqual(traverse_obj(_TEST_DATA, [{ + 'all': [('dict', 'None', 100, 1.2), all], + 'any': [('dict', 'None', 100, 1.2), any], + }]), {'all': [100, 1.2], 'any': 100}, + msg='`all`/`any` should apply to each dict path separately') + self.assertEqual(traverse_obj(_TEST_DATA, [{ + 'all': [('dict', 'None', 100, 1.2), all], + 'any': [('dict', 'None', 100, 1.2), any], + }], get_all=False), {'all': [100, 1.2], 'any': 100}, + msg='`all`/`any` should apply to dict regardless of `get_all`') + self.assertIs(traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, T(float)]), None, + msg='`all` should reset branching status') + self.assertIs(traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), any, T(float)]), None, + msg='`any` should reset branching status') + self.assertEqual(traverse_obj(_TEST_DATA, [('dict', 'None', 100, 1.2), all, Ellipsis, T(float)]), [1.2], + msg='`all` should allow further branching') + self.assertEqual(traverse_obj(_TEST_DATA, [('dict', 'None', 'urls', 'data'), any, Ellipsis, 'index']), [0, 1], + msg='`any` should allow further branching') + + def test_traversal_morsel(self): + values = { + 'expires': 'a', + 'path': 'b', + 'comment': 'c', + 'domain': 'd', + 'max-age': 'e', + 'secure': 'f', + 'httponly': 'g', + 'version': 'h', + 'samesite': 'i', + } + # SameSite added in Py3.8, breaks .update for 3.5-3.7 + if sys.version_info < (3, 8): + del values['samesite'] + morsel = compat_http_cookies.Morsel() + morsel.set(str('item_key'), 'item_value', 'coded_value') + morsel.update(values) + values['key'] = str('item_key') + values['value'] = 'item_value' + values = dict((str(k), v) for k, v in values.items()) + # make test pass even without ordered dict + value_set = set(values.values()) + + for key, value in values.items(): + self.assertEqual(traverse_obj(morsel, key), value, + msg='Morsel should provide access to all values') + self.assertEqual(set(traverse_obj(morsel, Ellipsis)), value_set, + msg='`...` should yield all values') + self.assertEqual(set(traverse_obj(morsel, lambda k, v: True)), value_set, + msg='function key should yield all values') + self.assertIs(traverse_obj(morsel, [(None,), any]), morsel, + msg='Morsel should not be implicitly changed to dict on usage') + def test_get_first(self): self.assertEqual(get_first([{'a': None}, {'a': 'spam'}], 'a'), 'spam') - def test_join_nonempty(self): - self.assertEqual(join_nonempty('a', 'b'), 'a-b') - self.assertEqual(join_nonempty( - 'a', 'b', 'c', 'd', - from_dict={'a': 'c', 'c': [], 'b': 'd', 'd': None}), 'c-d') + def test_dict_get(self): + FALSE_VALUES = { + 'none': None, + 'false': False, + 'zero': 0, + 'empty_string': '', + 'empty_list': [], + } + d = FALSE_VALUES.copy() + d['a'] = 42 + self.assertEqual(dict_get(d, 'a'), 42) + self.assertEqual(dict_get(d, 'b'), None) + self.assertEqual(dict_get(d, 'b', 42), 42) + self.assertEqual(dict_get(d, ('a', )), 42) + self.assertEqual(dict_get(d, ('b', 'a', )), 42) + self.assertEqual(dict_get(d, ('b', 'c', 'a', 'd', )), 42) + self.assertEqual(dict_get(d, ('b', 'c', )), None) + self.assertEqual(dict_get(d, ('b', 'c', ), 42), 42) + for key, false_value in FALSE_VALUES.items(): + self.assertEqual(dict_get(d, ('b', 'c', key, )), None) + self.assertEqual(dict_get(d, ('b', 'c', key, ), skip_false_values=False), false_value) if __name__ == '__main__': diff --git a/youtube_dl/compat.py b/youtube_dl/compat.py index 53ff2a892..d5485c7e8 100644 --- a/youtube_dl/compat.py +++ b/youtube_dl/compat.py @@ -2719,8 +2719,14 @@ if sys.version_info < (2, 7): if isinstance(xpath, compat_str): xpath = xpath.encode('ascii') return xpath + + def compat_etree_iterfind(element, match): + for from_ in element.findall(match): + yield from_ + else: compat_xpath = lambda xpath: xpath + compat_etree_iterfind = lambda element, match: element.iterfind(match) compat_os_name = os._name if os.name == 'java' else os.name @@ -2955,7 +2961,7 @@ except ImportError: return self def __exit__(self, exc_type, exc_val, exc_tb): - return exc_val is not None and isinstance(exc_val, self._exceptions or tuple()) + return exc_type is not None and issubclass(exc_type, self._exceptions or tuple()) # subprocess.Popen context manager @@ -3308,6 +3314,7 @@ __all__ = [ 'compat_contextlib_suppress', 'compat_ctypes_WINFUNCTYPE', 'compat_etree_fromstring', + 'compat_etree_iterfind', 'compat_filter', 'compat_get_terminal_size', 'compat_getenv', diff --git a/youtube_dl/utils.py b/youtube_dl/utils.py index e1b05b307..cd4303566 100644 --- a/youtube_dl/utils.py +++ b/youtube_dl/utils.py @@ -49,11 +49,14 @@ from .compat import ( compat_cookiejar, compat_ctypes_WINFUNCTYPE, compat_datetime_timedelta_total_seconds, + compat_etree_Element, compat_etree_fromstring, + compat_etree_iterfind, compat_expanduser, compat_html_entities, compat_html_entities_html5, compat_http_client, + compat_http_cookies, compat_integer_types, compat_kwargs, compat_ncompress as ncompress, @@ -6253,15 +6256,16 @@ if __debug__: def traverse_obj(obj, *paths, **kwargs): """ - Safely traverse nested `dict`s and `Iterable`s + Safely traverse nested `dict`s and `Iterable`s, etc >>> obj = [{}, {"key": "value"}] >>> traverse_obj(obj, (1, "key")) - "value" + 'value' Each of the provided `paths` is tested and the first producing a valid result will be returned. The next path will also be tested if the path branched but no results could be found. - Supported values for traversal are `Mapping`, `Iterable` and `re.Match`. + Supported values for traversal are `Mapping`, `Iterable`, `re.Match`, `xml.etree.ElementTree` + (xpath) and `http.cookies.Morsel`. Unhelpful values (`{}`, `None`) are treated as the absence of a value and discarded. The paths will be wrapped in `variadic`, so that `'key'` is conveniently the same as `('key', )`. @@ -6269,8 +6273,9 @@ def traverse_obj(obj, *paths, **kwargs): The keys in the path can be one of: - `None`: Return the current object. - `set`: Requires the only item in the set to be a type or function, - like `{type}`/`{func}`. If a `type`, returns only values - of this type. If a function, returns `func(obj)`. + like `{type}`/`{type, type, ...}`/`{func}`. If one or more `type`s, + return only values that have one of the types. If a function, + return `func(obj)`. - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`. - `slice`: Branch out and return all values in `obj[key]`. - `Ellipsis`: Branch out and return a list of all values. @@ -6282,8 +6287,10 @@ def traverse_obj(obj, *paths, **kwargs): For `Iterable`s, `key` is the enumeration count of the value. For `re.Match`es, `key` is the group number (0 = full match) as well as additionally any group names, if given. - - `dict` Transform the current object and return a matching dict. + - `dict`: Transform the current object and return a matching dict. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. + - `any`-builtin: Take the first matching object and return it, resetting branching. + - `all`-builtin: Take all matching objects and return them as a list, resetting branching. `tuple`, `list`, and `dict` all support nested paths and branches. @@ -6299,10 +6306,8 @@ def traverse_obj(obj, *paths, **kwargs): @param get_all If `False`, return the first matching result, otherwise all matching ones. @param casesense If `False`, consider string dictionary keys as case insensitive. - The following are only meant to be used by YoutubeDL.prepare_outtmpl and are not part of the API + The following is only meant to be used by YoutubeDL.prepare_outtmpl and is not part of the API - @param _is_user_input Whether the keys are generated from user input. - If `True` strings get converted to `int`/`slice` if needed. @param _traverse_string Whether to traverse into objects as strings. If `True`, any non-compatible object will first be converted into a string and then traversed into. @@ -6322,7 +6327,6 @@ def traverse_obj(obj, *paths, **kwargs): expected_type = kwargs.get('expected_type') get_all = kwargs.get('get_all', True) casesense = kwargs.get('casesense', True) - _is_user_input = kwargs.get('_is_user_input', False) _traverse_string = kwargs.get('_traverse_string', False) # instant compat @@ -6336,10 +6340,8 @@ def traverse_obj(obj, *paths, **kwargs): type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) def lookup_or_none(v, k, getter=None): - try: + with compat_contextlib_suppress(LookupError): return getter(v, k) if getter else v[k] - except IndexError: - return None def from_iterable(iterables): # chain.from_iterable(['ABC', 'DEF']) --> A B C D E F @@ -6361,12 +6363,13 @@ def traverse_obj(obj, *paths, **kwargs): result = obj elif isinstance(key, set): - assert len(key) == 1, 'Set should only be used to wrap a single item' - item = next(iter(key)) - if isinstance(item, type): - result = obj if isinstance(obj, item) else None + assert len(key) >= 1, 'At least one item is required in a `set` key' + if all(isinstance(item, type) for item in key): + result = obj if isinstance(obj, tuple(key)) else None else: - result = try_call(item, args=(obj,)) + item = next(iter(key)) + assert len(key) == 1, 'Multiple items in a `set` key must all be types' + result = try_call(item, args=(obj,)) if not isinstance(item, type) else None elif isinstance(key, (list, tuple)): branching = True @@ -6375,9 +6378,11 @@ def traverse_obj(obj, *paths, **kwargs): elif key is Ellipsis: branching = True + if isinstance(obj, compat_http_cookies.Morsel): + obj = dict(obj, key=obj.key, value=obj.value) if isinstance(obj, compat_collections_abc.Mapping): result = obj.values() - elif is_iterable_like(obj): + elif is_iterable_like(obj, (compat_collections_abc.Iterable, compat_etree_Element)): result = obj elif isinstance(obj, compat_re_Match): result = obj.groups() @@ -6389,9 +6394,11 @@ def traverse_obj(obj, *paths, **kwargs): elif callable(key): branching = True + if isinstance(obj, compat_http_cookies.Morsel): + obj = dict(obj, key=obj.key, value=obj.value) if isinstance(obj, compat_collections_abc.Mapping): iter_obj = obj.items() - elif is_iterable_like(obj): + elif is_iterable_like(obj, (compat_collections_abc.Iterable, compat_etree_Element)): iter_obj = enumerate(obj) elif isinstance(obj, compat_re_Match): iter_obj = itertools.chain( @@ -6413,6 +6420,8 @@ def traverse_obj(obj, *paths, **kwargs): if v is not None or default is not NO_DEFAULT) or None elif isinstance(obj, compat_collections_abc.Mapping): + if isinstance(obj, compat_http_cookies.Morsel): + obj = dict(obj, key=obj.key, value=obj.value) result = (try_call(obj.get, args=(key,)) if casesense or try_call(obj.__contains__, args=(key,)) else next((v for k, v in obj.items() if casefold(k) == key), None)) @@ -6430,12 +6439,40 @@ def traverse_obj(obj, *paths, **kwargs): else: result = None if isinstance(key, (int, slice)): - if is_iterable_like(obj, compat_collections_abc.Sequence): + if is_iterable_like(obj, (compat_collections_abc.Sequence, compat_etree_Element)): branching = isinstance(key, slice) result = lookup_or_none(obj, key) elif _traverse_string: result = lookup_or_none(str(obj), key) + elif isinstance(obj, compat_etree_Element) and isinstance(key, str): + xpath, _, special = key.rpartition('/') + if not special.startswith('@') and not special.endswith('()'): + xpath = key + special = None + + # Allow abbreviations of relative paths, absolute paths error + if xpath.startswith('/'): + xpath = '.' + xpath + elif xpath and not xpath.startswith('./'): + xpath = './' + xpath + + def apply_specials(element): + if special is None: + return element + if special == '@': + return element.attrib + if special.startswith('@'): + return try_call(element.attrib.get, args=(special[1:],)) + if special == 'text()': + return element.text + raise SyntaxError('apply_specials is missing case for {0!r}'.format(special)) + + if xpath: + result = list(map(apply_specials, compat_etree_iterfind(obj, xpath))) + else: + result = apply_specials(obj) + return branching, result if branching else (result,) def lazy_last(iterable): @@ -6456,17 +6493,18 @@ def traverse_obj(obj, *paths, **kwargs): key = None for last, key in lazy_last(variadic(path, (str, bytes, dict, set))): - if _is_user_input and isinstance(key, str): - if key == ':': - key = Ellipsis - elif ':' in key: - key = slice(*map(int_or_none, key.split(':'))) - elif int_or_none(key) is not None: - key = int(key) - if not casesense and isinstance(key, str): key = compat_casefold(key) + if key in (any, all): + has_branched = False + filtered_objs = (obj for obj in objs if obj not in (None, {})) + if key is any: + objs = (next(filtered_objs, None),) + else: + objs = (list(filtered_objs),) + continue + if __debug__ and callable(key): # Verify function signature _try_bind_args(key, None, None) @@ -6505,9 +6543,9 @@ def traverse_obj(obj, *paths, **kwargs): return None if default is NO_DEFAULT else default -def T(x): - """ For use in yt-dl instead of {type} or set((type,)) """ - return set((x,)) +def T(*x): + """ For use in yt-dl instead of {type, ...} or set((type, ...)) """ + return set(x) def get_first(obj, keys, **kwargs):