
Commit

downloads: fix deflate decoding and add optional zstd to accepted encodings (#594)

* downloads: fix deflate and add optional zstd to accepted encodings

* polish

* better logging and minimal version
adbar authored May 15, 2024
1 parent 2f66f1c commit f98f557
Showing 3 changed files with 83 additions and 35 deletions.
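
The gist of the change: the default accept-encoding header only advertises codecs that can actually be decoded locally, and the decoder no longer trusts response headers but probes the payload itself, checking the gzip and zstd magic numbers first before falling back to brotli and zlib/deflate. A minimal standalone sketch of that logic (illustrative only, not the exact trafilatura code; the helper names are made up here):

import gzip
import zlib

try:
    import brotli
except ImportError:
    brotli = None
try:
    import zstandard
except ImportError:
    zstandard = None

def build_accept_encoding():
    # hypothetical helper: only advertise what can be decoded locally
    codecs = ["deflate", "gzip"]
    if brotli is not None:
        codecs.append("br")
    if zstandard is not None:
        codecs.append("zstd")
    return ",".join(codecs)

def try_decompress(payload: bytes) -> bytes:
    # hypothetical helper mirroring the cascade added to trafilatura/utils.py below
    if payload[:3] == b"\x1f\x8b\x08":  # gzip magic number
        try:
            return gzip.decompress(payload)
        except (EOFError, OSError):
            pass
    if zstandard is not None and payload[:4] == b"\x28\xb5\x2f\xfd":  # zstd frame magic
        try:
            return zstandard.decompress(payload)
        except zstandard.ZstdError:
            pass
    if brotli is not None:  # brotli has no magic number, just attempt it
        try:
            return brotli.decompress(payload)
        except brotli.error:
            pass
    try:
        return zlib.decompress(payload)  # zlib-wrapped deflate
    except zlib.error:
        return payload  # give the data back unchanged
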
1 change: 1 addition & 0 deletions setup.py
@@ -31,6 +31,7 @@ def get_long_description():
        "htmldate[speed] >= 1.8.1",
        "py3langid >= 0.2.2",
        "pycurl >= 7.45.3",
        "zstandard >= 0.20.0",
    ],
    "gui": [
        "Gooey >= 1.0.1",
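The zstandard pin above goes into an optional extras list (the same block that holds pycurl and py3langid, presumably the "all" extra), so a plain install is unchanged; pulling in the optional compression support would look roughly like this, assuming that extras name:

pip install "trafilatura[all]"  # assumption: the extras group shown above is named "all"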
54 changes: 39 additions & 15 deletions tests/downloads_tests.py
@@ -7,16 +7,25 @@
import logging
import os
import sys
import zlib

try:
    import pycurl
    HAS_PYCURL = True
except ImportError:
    pycurl = None
    HAS_PYCURL = False

try:
    import brotli
    HAS_BROTLI = True
except ImportError:
    brotli = None
    HAS_BROTLI = False

try:
    import zstandard
    HAS_ZSTD = True
except ImportError:
    HAS_ZSTD = False

from time import sleep
from unittest.mock import patch
@@ -38,7 +47,7 @@
add_to_compressed_dict, fetch_url,
is_live_page, load_download_buffer)
from trafilatura.settings import DEFAULT_CONFIG, args_to_extractor, use_config
from trafilatura.utils import decode_file, decode_response, load_html
from trafilatura.utils import decode_file, decode_response, handle_compressed_file, load_html

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

@@ -86,7 +95,7 @@ def test_fetch():
    assert _urllib3_is_live_page('https://httpbun.com/status/404') is False
    assert is_live_page('https://httpbun.com/status/403') is False
    # is_live pycurl tests
    if pycurl is not None:
    if HAS_PYCURL:
        assert _pycurl_is_live_page('https://httpbun.com/status/301') is True

# fetch_url
@@ -95,15 +104,15 @@
    # test if the functions default to no_ssl
    # doesn't work?
    # assert _send_urllib_request('https://expired.badssl.com/', False, False, DEFAULT_CONFIG) is not None
    if pycurl is not None:
    if HAS_PYCURL:
        assert _send_pycurl_request('https://expired.badssl.com/', False, False, DEFAULT_CONFIG) is not None
    # no SSL, no decoding
    url = 'https://httpbun.com/status/200'
    for no_ssl in (True, False):
        response = _send_urllib_request('https://httpbun.com/status/200', no_ssl, True, DEFAULT_CONFIG)
        assert b"200" in response.data and b"OK" in response.data  # JSON
        assert response.headers["x-powered-by"].startswith("httpbun")
    if pycurl is not None:
    if HAS_PYCURL:
        response1 = _send_pycurl_request('https://httpbun.com/status/200', True, True, DEFAULT_CONFIG)
        assert response1.headers["x-powered-by"].startswith("httpbun")
        assert _handle_response(url, response1, False, DEFAULT_OPTS).data == _handle_response(url, response, False, DEFAULT_OPTS).data
@@ -137,7 +146,7 @@ def test_fetch():
    res = fetch_url('https://httpbun.com/redirect/1', config=new_config)
    assert res is None
    # Also test max redir implementation on pycurl if available
    if pycurl is not None:
    if HAS_PYCURL:
        assert _send_pycurl_request('https://httpbun.com/redirect/1', True, False, new_config) is None
    _reset_downloads_global_objects()  # reset global objects again to avoid affecting other tests

@@ -147,10 +156,12 @@ def test_config():
    # default config is none
    assert _parse_config(DEFAULT_CONFIG) == (None, None)
    # default accept-encoding
    if brotli is None:
        assert DEFAULT_HEADERS['accept-encoding'].endswith(',deflate')
    else:
        assert DEFAULT_HEADERS['accept-encoding'].endswith(',br')
    accepted = ['deflate', 'gzip']
    if HAS_BROTLI:
        accepted.append('br')
    if HAS_ZSTD:
        accepted.append('zstd')
    assert sorted(DEFAULT_HEADERS['accept-encoding'].split(',')) == sorted(accepted)
    # default user-agent
    default = _determine_headers(DEFAULT_CONFIG)
    assert default['User-Agent'] == USER_AGENT
@@ -164,19 +175,32 @@ def test_config():

def test_decode():
    '''Test how responses are being decoded.'''
    html_string = "<html><head/><body><div>ABC</div></body></html>"
    # response type
    data = b" "
    assert decode_file(data) is not None
    assert decode_file(b" ") is not None
    # GZip
    html_string = "<html><head/><body><div>ABC</div></body></html>"
    gz_string = gzip.compress(html_string.encode("utf-8"))
    assert handle_compressed_file(gz_string) == html_string.encode("utf-8")
    assert decode_file(gz_string) == html_string
    with pytest.raises(ValueError):
        decode_response(gz_string)
    # Deflate
    deflate_string = zlib.compress(html_string.encode("utf-8"))
    assert handle_compressed_file(deflate_string) == html_string.encode("utf-8")
    assert decode_file(deflate_string) == html_string
    # Brotli
    if brotli is not None:
    if HAS_BROTLI:
        brotli_string = brotli.compress(html_string.encode("utf-8"))
        assert handle_compressed_file(brotli_string) == html_string.encode("utf-8")
        assert decode_file(brotli_string) == html_string
    # ZStandard
    if HAS_ZSTD:
        zstd_string = zstandard.compress(html_string.encode("utf-8"))
        assert handle_compressed_file(zstd_string) == html_string.encode("utf-8")
        assert decode_file(zstd_string) == html_string
    # errors
    for bad_file in ("äöüß", b"\x1f\x8b\x08abc", b"\x28\xb5\x2f\xfdabc"):
        assert handle_compressed_file(bad_file) == bad_file


def test_queue():
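To exercise the new decoding and header tests locally, an invocation along these lines should work (an assumption, since the repository's preferred test runner is not shown here):

python -m pytest tests/downloads_tests.py -k "decode or config"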
63 changes: 43 additions & 20 deletions trafilatura/utils.py
@@ -4,20 +4,28 @@
content filtering and language detection.
"""

import gzip
import logging
import re
import zlib

from functools import lru_cache
from gzip import decompress
from html import unescape
from itertools import islice
from unicodedata import normalize

# if brotli is installed
# response compression
try:
    import brotli
    HAS_BROTLI = True
except ImportError:
    brotli = None
    HAS_BROTLI = False

try:
    import zstandard
    HAS_ZSTD = True
except ImportError:
    HAS_ZSTD = False

# language detection
try:
@@ -93,23 +101,38 @@


def handle_compressed_file(filecontent):
    """Tell if a file's magic number corresponds to the GZip format
    and try to decode it. Alternatively, try Brotli if the package
    is installed."""
    if isinstance(filecontent, bytes):
        # source: https://stackoverflow.com/questions/3703276/how-to-tell-if-a-file-is-gzip-compressed
        if filecontent[:2] == b'\x1f\x8b':
            # decode GZipped data
            try:
                filecontent = decompress(filecontent)
            except (EOFError, OSError):
                logging.warning('invalid GZ file')
        # try brotli
        elif brotli is not None:
            try:
                filecontent = brotli.decompress(filecontent)
            except brotli.error:
                pass  # logging.debug('invalid Brotli file')
    """
    Don't trust response headers and try to decompress a binary string
    with a cascade of installed packages. Use magic numbers when available.
    """
    if not isinstance(filecontent, bytes):
        return filecontent

    # source: https://stackoverflow.com/questions/3703276/how-to-tell-if-a-file-is-gzip-compressed
    if filecontent[:3] == b"\x1f\x8b\x08":
        try:
            return gzip.decompress(filecontent)
        except Exception:  # EOFError, OSError, gzip.BadGzipFile
            LOGGER.warning("invalid GZ file")
    # try zstandard
    if HAS_ZSTD and filecontent[:4] == b"\x28\xb5\x2f\xfd":
        try:
            return zstandard.decompress(filecontent)  # max_output_size=???
        except zstandard.ZstdError:
            LOGGER.warning("invalid ZSTD file")
    # try brotli
    if HAS_BROTLI:
        try:
            return brotli.decompress(filecontent)
        except brotli.error:
            pass  # logging.debug('invalid Brotli file')
    # try zlib/deflate
    try:
        return zlib.decompress(filecontent)
    except zlib.error:
        pass

    # return content unchanged if decompression failed
    return filecontent


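A short usage sketch of the reworked helper, mirroring what the tests above assert (assumes a trafilatura checkout containing this commit; zstandard and brotli stay optional):

import gzip
import zlib

from trafilatura.utils import decode_file, handle_compressed_file

html = "<html><head/><body><div>ABC</div></body></html>"
raw = html.encode("utf-8")

assert handle_compressed_file(gzip.compress(raw)) == raw  # gzip, detected via magic number
assert handle_compressed_file(zlib.compress(raw)) == raw  # deflate/zlib fallback
assert handle_compressed_file(b"plain bytes") == b"plain bytes"  # returned unchanged on failure
assert decode_file(gzip.compress(raw)) == html  # decompress and decode to str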
