index : reflector32 | |
Archlinux32 fork of reflector | gitolite user |
summaryrefslogtreecommitdiff |
-rw-r--r-- | Reflector.py | 135 |
diff --git a/Reflector.py b/Reflector.py index bb56b56..95b6798 100644 --- a/Reflector.py +++ b/Reflector.py @@ -37,6 +37,7 @@ import re import shlex import socket import subprocess +import signal import sys import tempfile import time @@ -59,6 +60,7 @@ MIRROR_URL_FORMAT = '{0}{1}/os/{2}' MIRRORLIST_ENTRY_FORMAT = "Server = " + MIRROR_URL_FORMAT + "\n" DEFAULT_CONNECTION_TIMEOUT = 5 +DEFAULT_DOWNLOAD_TIMEOUT = 5 DEFAULT_CACHE_TIMEOUT = 300 SORT_TYPES = { @@ -165,9 +167,55 @@ def count_countries(mirrors): return countries +# ------------------------ download timeout handling ------------------------- # + +class DownloadTimeout(Exception): + ''' + Download timeout exception raised by DownloadContext. + ''' + + +class DownloadTimer(): + ''' + Context manager for timing downloads with timeouts. + ''' + def __init__(self, timeout=DEFAULT_DOWNLOAD_TIMEOUT): + ''' + Args: + timeout: + The download timeout in seconds. The DownloadTimeout exception + will be raised in the context after this many seconds. + ''' + self.time = None + self.start_time = None + self.timeout = timeout + self.previous_handler = None + self.previous_timer = None + + def raise_timeout(self, signl, frame): + ''' + Raise the DownloadTimeout exception. + ''' + raise DownloadTimeout(f'Download timed out after {self.timeout} second(s).') + + def __enter__(self): + self.start_time = time.time() + if self.timeout > 0: + self.previous_handler = signal.signal(signal.SIGALRM, self.raise_timeout) + self.previous_timer = signal.alarm(self.timeout) + return self + + def __exit__(self, typ, value, traceback): + self.time = time.time() - self.start_time + if self.timeout > 0: + signal.signal(signal.SIGALRM, self.previous_handler) + signal.alarm(self.previous_timer) + self.start_time = None + + # --------------------------------- Sorting ---------------------------------- # -def sort(mirrors, by=None): # pylint: disable=invalid-name +def sort(mirrors, by=None, **kwargs): # pylint: disable=invalid-name ''' Sort mirrors by different criteria. ''' @@ -179,7 +227,7 @@ def sort(mirrors, by=None): # pylint: disable=invalid-name mirrors.sort(key=lambda m: m['last_sync'], reverse=True) elif by == 'rate': - rates = rate(mirrors) + rates = rate(mirrors, **kwargs) mirrors = sorted(mirrors, key=lambda m: rates[m['url']], reverse=True) else: @@ -193,7 +241,11 @@ def sort(mirrors, by=None): # pylint: disable=invalid-name # ---------------------------------- Rating ---------------------------------- # -def rate_rsync(db_url, connection_timeout=DEFAULT_CONNECTION_TIMEOUT): +def rate_rsync( + db_url, + connection_timeout=DEFAULT_CONNECTION_TIMEOUT, + download_timeout=DEFAULT_DOWNLOAD_TIMEOUT +): ''' Download a database via rsync and return the time and rate of the download. ''' @@ -205,43 +257,64 @@ def rate_rsync(db_url, connection_timeout=DEFAULT_CONNECTION_TIMEOUT): ] try: with tempfile.TemporaryDirectory() as tmpdir: - time_0 = time.time() - subprocess.check_call( - rsync_cmd + [tmpdir], - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL - ) - time_delta = time.time() - time_0 + with DownloadTimer(timeout=download_timeout) as timer: + subprocess.check_call( + rsync_cmd + [tmpdir], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL + ) + time_delta = timer.time size = os.path.getsize( os.path.join(tmpdir, os.path.basename(DB_SUBPATH)) ) ratio = size / time_delta return time_delta, ratio - except (subprocess.CalledProcessError, subprocess.TimeoutExpired, FileNotFoundError): + except ( + subprocess.CalledProcessError, + subprocess.TimeoutExpired, + FileNotFoundError, + DownloadTimeout + ) as err: + logger = get_logger() + logger.info('failed to rate rsync download (%s): %s', db_url, err) return 0, 0 -def rate_http(db_url, connection_timeout=DEFAULT_CONNECTION_TIMEOUT): +def rate_http( + db_url, + connection_timeout=DEFAULT_CONNECTION_TIMEOUT, + download_timeout=DEFAULT_DOWNLOAD_TIMEOUT +): ''' Download a database via any protocol supported by urlopen and return the time and rate of the download. ''' req = urllib.request.Request(url=db_url) try: - with urllib.request.urlopen(req, None, connection_timeout) as handle: - time_0 = time.time() + with urllib.request.urlopen(req, None, connection_timeout) as handle, \ + DownloadTimer(timeout=download_timeout) as timer: size = len(handle.read()) - time_delta = time.time() - time_0 + time_delta = timer.time ratio = size / time_delta return time_delta, ratio - except (OSError, urllib.error.HTTPError, http.client.HTTPException): + except ( + OSError, + urllib.error.HTTPError, + http.client.HTTPException, + DownloadTimeout + ) as err: + logger = get_logger() + logger.info('failed to rate http(s) download (%s): %s', db_url, err) return 0, 0 -def rate(mirrors, connection_timeout=DEFAULT_CONNECTION_TIMEOUT): +def rate( + mirrors, + **kwargs +): ''' Rate mirrors by timing the download of the community repo's database from - each one. + each one. Keyword arguments are passed through to rate_rsync and rate_http. ''' # Ensure that mirrors is not a generator so that its length can be determined. if not isinstance(mirrors, tuple): @@ -265,9 +338,9 @@ def rate(mirrors, connection_timeout=DEFAULT_CONNECTION_TIMEOUT): scheme = urllib.parse.urlparse(url).scheme if scheme == 'rsync': - time_delta, ratio = rate_rsync(db_url, connection_timeout) + time_delta, ratio = rate_rsync(db_url, **kwargs) else: - time_delta, ratio = rate_http(db_url, connection_timeout) + time_delta, ratio = rate_http(db_url, **kwargs) kibps = ratio / 1024.0 logger.info(fmt.format(url, kibps, time_delta)) @@ -276,7 +349,7 @@ def rate(mirrors, connection_timeout=DEFAULT_CONNECTION_TIMEOUT): return rates -# ---------------------------- MirrorStatusError ----------------------------- # +# -------------------------------- Exceptions -------------------------------- # class MirrorStatusError(Exception): ''' @@ -468,11 +541,13 @@ class MirrorStatus(): def __init__( self, connection_timeout=DEFAULT_CONNECTION_TIMEOUT, + download_timeout=DEFAULT_DOWNLOAD_TIMEOUT, cache_timeout=DEFAULT_CACHE_TIMEOUT, min_completion_pct=1.0, url=URL ): self.connection_timeout = connection_timeout + self.download_timeout = download_timeout self.cache_timeout = cache_timeout self.min_completion_pct = min_completion_pct self.url = url @@ -518,21 +593,21 @@ class MirrorStatus(): msf = MirrorStatusFilter(min_completion_pct=self.min_completion_pct, **kwargs) yield from msf.filter_mirrors(mirrors) - def sort(self, mirrors=None, **kwargs): + def sort(self, mirrors, **kwargs): ''' Sort mirrors by various criteria. ''' if mirrors is None: mirrors = self.get_mirrors() + kwargs.setdefault('connection_timeout', self.connection_timeout) + kwargs.setdefault('download_timeout', self.download_timeout) yield from sort(mirrors, **kwargs) def rate(self, mirrors=None, **kwargs): ''' Sort mirrors by download speed. ''' - if mirrors is None: - mirrors = self.get_mirrors() - yield from sort(mirrors, by='rate', **kwargs) + yield from self.sort(mirrors, by='rate', **kwargs) def get_mirrorlist(self, mirrors=None, include_country=False, cmd=None): ''' @@ -601,10 +676,10 @@ def add_arguments(parser): help='The number of seconds to wait before a connection times out. Default: %(default)s' ) -# parser.add_argument( -# '--download-timeout', type=int, metavar='n', -# help='The number of seconds to wait before a download times out. The threshold is checked after each chunk is read, so the actual timeout may take longer.' -# ) + parser.add_argument( + '--download-timeout', type=int, metavar='n', default=DEFAULT_DOWNLOAD_TIMEOUT, + help='The number of seconds to wait before a download times out. Default: %(default)s' + ) parser.add_argument( '--list-countries', action=ListCountries, nargs=0, @@ -763,7 +838,7 @@ def process_options(options, mirrorstatus=None, mirrors=None): if not mirrorstatus: mirrorstatus = MirrorStatus( connection_timeout=options.connection_timeout, - # download_timeout=options.download_timeout, + download_timeout=options.download_timeout, cache_timeout=options.cache_timeout, min_completion_pct=(options.completion_percent / 100.), url=options.url |