index : archinstall32 | |
Archlinux32 installer | gitolite user |
summaryrefslogtreecommitdiff |
-rw-r--r-- | archinstall/lib/general.py | 179 |
diff --git a/archinstall/lib/general.py b/archinstall/lib/general.py index 48de4cbe..ea0bafc9 100644 --- a/archinstall/lib/general.py +++ b/archinstall/lib/general.py @@ -9,10 +9,11 @@ import string import sys import time from datetime import datetime, date -from typing import Union -try: +from typing import Callable, Optional, Dict, Any, List, Union, Iterator + +if sys.platform == 'linux': from select import epoll, EPOLLIN, EPOLLHUP -except: +else: import select EPOLLIN = 0 EPOLLHUP = 0 @@ -22,20 +23,20 @@ except: Create a epoll() implementation that simulates the epoll() behavior. This so that the rest of the code doesn't need to worry weither we're using select() or epoll(). """ - def __init__(self): - self.sockets = {} - self.monitoring = {} + def __init__(self) -> None: + self.sockets: Dict[str, Any] = {} + self.monitoring: Dict[int, Any] = {} - def unregister(self, fileno, *args, **kwargs): + def unregister(self, fileno :int, *args :List[Any], **kwargs :Dict[str, Any]) -> None: try: del(self.monitoring[fileno]) except: pass - def register(self, fileno, *args, **kwargs): + def register(self, fileno :int, *args :int, **kwargs :Dict[str, Any]) -> None: self.monitoring[fileno] = True - def poll(self, timeout=0.05, *args, **kwargs): + def poll(self, timeout: float = 0.05, *args :str, **kwargs :Dict[str, Any]) -> List[Any]: try: return [[fileno, 1] for fileno in select.select(list(self.monitoring.keys()), [], [], timeout)[0]] except OSError: @@ -66,13 +67,13 @@ def multisplit(s, splitters): s = ns return s -def locate_binary(name): +def locate_binary(name :str) -> str: for PATH in os.environ['PATH'].split(':'): for root, folders, files in os.walk(PATH): for file in files: if file == name: return os.path.join(root, file) - break # Don't recurse + break # Don't recurse raise RequirementError(f"Binary {name} does not exist.") @@ -157,7 +158,14 @@ class UNSAFE_JSON(json.JSONEncoder, json.JSONDecoder): return super(UNSAFE_JSON, self).encode(self._encode(obj)) class SysCommandWorker: - def __init__(self, cmd, callbacks=None, peak_output=False, environment_vars=None, logfile=None, working_directory='./'): + def __init__(self, + cmd :Union[str, List[str]], + callbacks :Optional[Dict[str, Any]] = None, + peak_output :Optional[bool] = False, + environment_vars :Optional[Dict[str, Any]] = None, + logfile :Optional[None] = None, + working_directory :Optional[str] = './'): + if not callbacks: callbacks = {} if not environment_vars: @@ -166,6 +174,7 @@ class SysCommandWorker: if type(cmd) is str: cmd = shlex.split(cmd) + cmd = list(cmd) # This is to please mypy if cmd[0][0] != '/' and cmd[0][:2] != './': # "which" doesn't work as it's a builtin to bash. # It used to work, but for whatever reason it doesn't anymore. @@ -179,15 +188,15 @@ class SysCommandWorker: self.logfile = logfile self.working_directory = working_directory - self.exit_code = None + self.exit_code :Optional[int] = None self._trace_log = b'' self._trace_log_pos = 0 self.poll_object = epoll() - self.child_fd = None - self.started = None - self.ended = None + self.child_fd :Optional[int] = None + self.started :Optional[float] = None + self.ended :Optional[float] = None - def __contains__(self, key: bytes): + def __contains__(self, key: bytes) -> bool: """ Contains will also move the current buffert position forward. This is to avoid re-checking the same data when looking for output. @@ -199,21 +208,21 @@ class SysCommandWorker: return contains - def __iter__(self, *args, **kwargs): + def __iter__(self, *args :str, **kwargs :Dict[str, Any]) -> Iterator[bytes]: for line in self._trace_log[self._trace_log_pos:self._trace_log.rfind(b'\n')].split(b'\n'): if line: yield line + b'\n' self._trace_log_pos = self._trace_log.rfind(b'\n') - def __repr__(self): + def __repr__(self) -> str: self.make_sure_we_are_executing() return str(self._trace_log) - def __enter__(self): + def __enter__(self) -> 'SysCommandWorker': return self - def __exit__(self, *args): + def __exit__(self, *args :str) -> None: # b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync. # TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager @@ -233,9 +242,9 @@ class SysCommandWorker: log(args[1], level=logging.ERROR, fg='red') if self.exit_code != 0: - raise SysCallError(f"{self.cmd} exited with abnormal exit code: {self.exit_code}") + raise SysCallError(f"{self.cmd} exited with abnormal exit code: {self.exit_code}", self.exit_code) - def is_alive(self): + def is_alive(self) -> bool: self.poll() if self.started and self.ended is None: @@ -243,22 +252,26 @@ class SysCommandWorker: return False - def write(self, data: bytes, line_ending=True): + def write(self, data: bytes, line_ending :bool = True) -> int: assert type(data) == bytes # TODO: Maybe we can support str as well and encode it self.make_sure_we_are_executing() - os.write(self.child_fd, data + (b'\n' if line_ending else b'')) + if self.child_fd: + return os.write(self.child_fd, data + (b'\n' if line_ending else b'')) - def make_sure_we_are_executing(self): + return 0 + + def make_sure_we_are_executing(self) -> bool: if not self.started: return self.execute() + return True def tell(self) -> int: self.make_sure_we_are_executing() return self._trace_log_pos - def seek(self, pos): + def seek(self, pos :int) -> None: self.make_sure_we_are_executing() # Safety check to ensure 0 < pos < len(tracelog) self._trace_log_pos = min(max(0, pos), len(self._trace_log)) @@ -271,39 +284,41 @@ class SysCommandWorker: except UnicodeDecodeError: return False - sys.stdout.write(output) + sys.stdout.write(str(output)) sys.stdout.flush() + return True - def poll(self): + def poll(self) -> None: self.make_sure_we_are_executing() - got_output = False - for fileno, event in self.poll_object.poll(0.1): - try: - output = os.read(self.child_fd, 8192) - got_output = True - self.peak(output) - self._trace_log += output - except OSError: + if self.child_fd: + got_output = False + for fileno, event in self.poll_object.poll(0.1): + try: + output = os.read(self.child_fd, 8192) + got_output = True + self.peak(output) + self._trace_log += output + except OSError: + self.ended = time.time() + break + + if self.ended or (got_output is False and pid_exists(self.pid) is False): self.ended = time.time() - break - - if self.ended or (got_output is False and pid_exists(self.pid) is False): - self.ended = time.time() - try: - self.exit_code = os.waitpid(self.pid, 0)[1] - except ChildProcessError: try: - self.exit_code = os.waitpid(self.child_fd, 0)[1] + self.exit_code = os.waitpid(self.pid, 0)[1] except ChildProcessError: - self.exit_code = 1 + try: + self.exit_code = os.waitpid(self.child_fd, 0)[1] + except ChildProcessError: + self.exit_code = 1 def execute(self) -> bool: import pty if (old_dir := os.getcwd()) != self.working_directory: - os.chdir(self.working_directory) + os.chdir(str(self.working_directory)) # Note: If for any reason, we get a Python exception between here # and until os.close(), the traceback will get locked inside @@ -320,7 +335,7 @@ class SysCommandWorker: except PermissionError: pass - os.execve(self.cmd[0], self.cmd, {**os.environ, **self.environment_vars}) + os.execve(self.cmd[0], list(self.cmd), {**os.environ, **self.environment_vars}) if storage['arguments'].get('debug'): log(f"Executing: {self.cmd}", level=logging.DEBUG) @@ -334,15 +349,23 @@ class SysCommandWorker: return True - def decode(self, encoding='UTF-8'): + def decode(self, encoding :str = 'UTF-8') -> str: return self._trace_log.decode(encoding) class SysCommand: - def __init__(self, cmd, callback=None, start_callback=None, peak_output=False, environment_vars=None, working_directory='./'): + def __init__(self, + cmd :Union[str, List[str]], + callbacks :Optional[Dict[str, Callable[[Any], Any]]] = None, + start_callback :Optional[Callable[[Any], Any]] = None, + peak_output :Optional[bool] = False, + environment_vars :Optional[Dict[str, Any]] = None, + working_directory :Optional[str] = './'): + _callbacks = {} - if callback: - _callbacks['on_end'] = callback + if callbacks: + for hook, func in callbacks.items(): + _callbacks[hook] = func if start_callback: _callbacks['on_start'] = start_callback @@ -352,26 +375,28 @@ class SysCommand: self.environment_vars = environment_vars self.working_directory = working_directory - self.session = None + self.session :Optional[SysCommandWorker] = None self.create_session() - def __enter__(self): + def __enter__(self) -> Optional[SysCommandWorker]: return self.session - def __exit__(self, *args, **kwargs): + def __exit__(self, *args :str, **kwargs :Dict[str, Any]) -> None: # b''.join(sys_command('sync')) # No need to, since the underlying fs() object will call sync. # TODO: https://stackoverflow.com/questions/28157929/how-to-safely-handle-an-exception-inside-a-context-manager if len(args) >= 2 and args[1]: log(args[1], level=logging.ERROR, fg='red') - def __iter__(self, *args, **kwargs): - - for line in self.session: - yield line + def __iter__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> Iterator[bytes]: + if self.session: + for line in self.session: + yield line - def __getitem__(self, key): - if type(key) is slice: + def __getitem__(self, key :slice) -> Optional[bytes]: + if not self.session: + raise KeyError(f"SysCommand() does not have an active session.") + elif type(key) is slice: start = key.start if key.start else 0 end = key.stop if key.stop else len(self.session._trace_log) @@ -379,10 +404,12 @@ class SysCommand: else: raise ValueError("SysCommand() doesn't have key & value pairs, only slices, SysCommand('ls')[:10] as an example.") - def __repr__(self, *args, **kwargs): - return self.session._trace_log.decode('UTF-8') + def __repr__(self, *args :List[Any], **kwargs :Dict[str, Any]) -> str: + if self.session: + return self.session._trace_log.decode('UTF-8') + return '' - def __json__(self): + def __json__(self) -> Dict[str, Union[str, bool, List[str], Dict[str, Any], Optional[bool], Optional[Dict[str, Any]]]]: return { 'cmd': self.cmd, 'callbacks': self._callbacks, @@ -391,7 +418,7 @@ class SysCommand: 'session': True if self.session else False } - def create_session(self): + def create_session(self) -> bool: if self.session: return True @@ -406,16 +433,23 @@ class SysCommand: return True - def decode(self, fmt='UTF-8'): - return self.session._trace_log.decode(fmt) + def decode(self, fmt :str = 'UTF-8') -> Optional[str]: + if self.session: + return self.session._trace_log.decode(fmt) + return None @property - def exit_code(self): - return self.session.exit_code + def exit_code(self) -> Optional[int]: + if self.session: + return self.session.exit_code + else: + return None @property - def trace_log(self): - return self.session._trace_log + def trace_log(self) -> Optional[bytes]: + if self.session: + return self.session._trace_log + return None def prerequisite_check(): @@ -428,7 +462,8 @@ def prerequisite_check(): def reboot(): SysCommand("/usr/bin/reboot") -def pid_exists(pid: int): + +def pid_exists(pid: int) -> bool: try: return any(subprocess.check_output(['/usr/bin/ps', '--no-headers', '-o', 'pid', '-p', str(pid)]).strip()) except subprocess.CalledProcessError: |