from __future__ import annotations

import dataclasses
import json
import math
import time
import uuid
from dataclasses import dataclass, field
from enum import Enum
from enum import auto
from pathlib import Path
from typing import Optional, List, Dict, TYPE_CHECKING, Any
from typing import Union

import parted  # type: ignore
import _ped  # type: ignore
from parted import Disk, Geometry, Partition

from ..exceptions import DiskError, SysCallError
from ..general import SysCommand
from ..output import debug, error
from ..output import info
from ..storage import storage

	_: Any

class DiskLayoutType(Enum):
	"""Source of the disk layout: a default layout, manual partitioning, or a pre-mounted configuration."""
	Default = 'default_layout'
	Manual = 'manual_partitioning'
	Pre_mount = 'pre_mounted_config'

	def display_msg(self) -> str:
		"""Return the translated, human-readable label of this layout type."""
		if self is DiskLayoutType.Default:
			return str(_('Use a best-effort default partition layout'))
		if self is DiskLayoutType.Manual:
			return str(_('Manual Partitioning'))
		if self is DiskLayoutType.Pre_mount:
			return str(_('Pre-mounted configuration'))

@dataclass
class DiskLayoutConfiguration:
	"""
	Complete disk layout: the layout type plus the modifications per device.

	For a pre-mounted configuration (DiskLayoutType.Pre_mount) the
	relative_mountpoint must be set.
	"""
	config_type: DiskLayoutType
	device_modifications: List[DeviceModification] = field(default_factory=list)
	# used for pre-mounted config
	relative_mountpoint: Optional[Path] = None

	def __post_init__(self):
		if self.config_type == DiskLayoutType.Pre_mount and self.relative_mountpoint is None:
			raise ValueError('Must set a relative mountpoint when layout type is pre-mount"')

	def __dump__(self) -> Dict[str, Any]:
		"""
		Serialize the layout for the configuration file.

		'relative_mountpoint' is only written when it is set (so layouts
		without it serialize exactly as before) and is read back by parse_arg().
		"""
		dump: Dict[str, Any] = {
			'config_type': self.config_type.value,
			'device_modifications': [mod.__dump__() for mod in self.device_modifications]
		}

		if self.relative_mountpoint is not None:
			dump['relative_mountpoint'] = str(self.relative_mountpoint)

		return dump

	@classmethod
	def parse_arg(cls, disk_config: Dict[str, Any]) -> Optional[DiskLayoutConfiguration]:
		"""
		Build a layout from the dict form produced by __dump__().

		Device entries without a 'device' path, or whose device cannot be
		found by the device handler, are skipped.

		:raises ValueError: if 'config_type' is missing or not a valid
			DiskLayoutType value, or a pre-mount layout has no relative_mountpoint
		"""
		from .device_handler import device_handler

		config_type = disk_config.get('config_type', None)

		if not config_type:
			raise ValueError('Missing disk layout configuration: config_type')

		relative_mountpoint = disk_config.get('relative_mountpoint', None)

		# the config holds this very list, so appending below fills the config
		device_modifications: List[DeviceModification] = []
		config = cls(
			config_type=DiskLayoutType(config_type),
			device_modifications=device_modifications,
			relative_mountpoint=Path(relative_mountpoint) if relative_mountpoint else None
		)

		for entry in disk_config.get('device_modifications', []):
			device_path = Path(entry['device']) if entry.get('device', None) else None

			if not device_path:
				continue

			device = device_handler.get_device(device_path)

			if not device:
				continue

			device_modification = DeviceModification(
				device=device,
				wipe=entry.get('wipe', False)
			)

			device_modification.partitions = [
				cls._parse_partition(partition) for partition in entry.get('partitions', [])
			]
			device_modifications.append(device_modification)

		return config

	@staticmethod
	def _parse_partition(partition: Dict[str, Any]) -> PartitionModification:
		"""Rebuild one PartitionModification from the dict produced by its json()."""
		fs_type = partition.get('fs_type', None)
		mountpoint = partition.get('mountpoint', None)
		dev_path = partition.get('dev_path', None)

		device_partition = PartitionModification(
			status=ModificationStatus(partition['status']),
			type=PartitionType(partition['type']),
			start=Size.parse_args(partition['start']),
			length=Size.parse_args(partition['length']),
			# json() writes '' for an unknown file system type; '' is not a valid member value
			fs_type=FilesystemType(fs_type) if fs_type else None,
			mountpoint=Path(mountpoint) if mountpoint else None,
			mount_options=partition.get('mount_options', []),
			flags=[PartitionFlag[f] for f in partition.get('flags', [])],
			dev_path=Path(dev_path) if dev_path else None,
			btrfs_subvols=SubvolumeModification.parse_args(partition.get('btrfs', [])),
		)
		# special 'invisible' attr to internally identify the part mod;
		# encryption settings refer to partitions by this id
		setattr(device_partition, '_obj_id', partition['obj_id'])
		return device_partition

class PartitionTable(Enum):
	"""Partition table formats; each value is the disk-label type string ('msdos' for MBR)."""
	GPT = 'gpt'  # GUID partition table
	MBR = 'msdos'  # DOS / master boot record partition table

class Unit(Enum):
	"""
	Units a Size can be expressed in.

	Byte-based members hold their multiplier in bytes (decimal SI units and
	binary IEC units). `sectors` and `Percent` are not byte multiples: they
	need a sector size or a total size, respectively, to be converted.
	"""
	B = 1          # byte
	kB = 1000**1   # kilobyte
	MB = 1000**2   # megabyte
	GB = 1000**3   # gigabyte
	TB = 1000**4   # terabyte
	PB = 1000**5   # petabyte
	EB = 1000**6   # exabyte
	ZB = 1000**7   # zettabyte
	YB = 1000**8   # yottabyte

	KiB = 1024**1 	# kibibyte
	MiB = 1024**2 	# mebibyte
	GiB = 1024**3  	# gibibyte
	TiB = 1024**4  	# tebibyte
	PiB = 1024**5  	# pebibyte
	EiB = 1024**6  	# exbibyte
	ZiB = 1024**7  	# zebibyte
	YiB = 1024**8  	# yobibyte

	sectors = 'sectors'  # size in sectors

	Percent = '%' 	# size as a percentage of a total size

	@staticmethod
	def get_all_units() -> List[str]:
		"""Return the names of all units, in declaration order."""
		return [u.name for u in Unit]

@dataclass
class Size:
	"""
	A size value together with its unit.

	`sector_size` is required when `unit` is Unit.sectors, and `total_size`
	(the size the percentage is taken of) when `unit` is Unit.Percent.
	Comparisons and subtraction operate on the size normalized to bytes.
	"""
	value: int
	unit: Unit
	sector_size: Optional[Size] = None  # only required when unit is sector
	total_size: Optional[Size] = None  # required when operating on percentages

	def __post_init__(self):
		if self.unit == Unit.sectors and self.sector_size is None:
			raise ValueError('Sector size is required when unit is sectors')
		elif self.unit == Unit.Percent:
			if self.value < 0 or self.value > 100:
				raise ValueError('Percentage must be between 0 and 100')
			elif self.total_size is None:
				raise ValueError('Total size is required when unit is percentage')

	@property
	def _total_size(self) -> Size:
		"""
		Safe accessor for the total size, mainly to satisfy mypy.

		Raising here should not happen: a percentage Size without a total
		size already fails in __post_init__.
		"""
		if self.unit == Unit.Percent and self.total_size is None:
			raise ValueError('Percent unit size must specify a total size')
		return self.total_size  # type: ignore

	def __dump__(self) -> Dict[str, Any]:
		"""Serialize to a dict that parse_args() can read back."""
		return {
			'value': self.value,
			'unit': self.unit.name,
			'sector_size': self.sector_size.__dump__() if self.sector_size else None,
			'total_size': self.total_size.__dump__() if self.total_size else None
		}

	@classmethod
	def parse_args(cls, size_arg: Dict[str, Any]) -> Size:
		"""Rebuild a Size from the dict produced by __dump__() (units are stored by name)."""
		sector_size = size_arg.get('sector_size', None)
		total_size = size_arg.get('total_size', None)

		return cls(
			size_arg['value'],
			Unit[size_arg['unit']],
			cls.parse_args(sector_size) if sector_size else None,
			cls.parse_args(total_size) if total_size else None
		)

	def convert(
		self,
		target_unit: Unit,
		sector_size: Optional[Size] = None,
		total_size: Optional[Size] = None
	) -> Size:
		"""
		Return this size expressed in `target_unit`.

		Conversions to byte units truncate towards zero; conversions to
		sectors round up so the result always covers the whole size.

		:param sector_size: required when target_unit is Unit.sectors
		:param total_size: required when target_unit is Unit.Percent
		:raises ValueError: if the parameter required by target_unit is missing,
			or a converted percentage would exceed 100
		"""
		if target_unit == Unit.sectors and sector_size is None:
			raise ValueError('If target has unit sector, a sector size must be provided')

		# not sure why we would ever wanna convert to percentages
		if target_unit == Unit.Percent and total_size is None:
			raise ValueError('Missing paramter total size to be able to convert to percentage')

		if self.unit == target_unit:
			# sectors of a different size are effectively a different unit
			if target_unit != Unit.sectors or self.sector_size == sector_size:
				return self

		if target_unit == Unit.Percent:
			# dividing by the unit value ('%') would raise TypeError
			total_bytes = total_size._normalize()  # type: ignore[union-attr]
			percentage = int(self._normalize() / total_bytes * 100) if total_bytes else 0
			return Size(percentage, Unit.Percent, total_size=total_size)

		if self.unit == Unit.Percent:
			amount = int(self._total_size._normalize() * (self.value / 100))
			# continue to the requested unit instead of always returning bytes
			return Size(amount, Unit.B).convert(target_unit, sector_size)

		if self.unit == Unit.sectors:
			return Size(self._normalize(), Unit.B).convert(target_unit, sector_size)

		norm = self._normalize()
		if target_unit == Unit.sectors:
			sector_bytes = sector_size._normalize()  # type: ignore[union-attr]
			sectors = math.ceil(norm / sector_bytes)
			return Size(sectors, Unit.sectors, sector_size)

		value = int(norm / target_unit.value)  # type: ignore
		return Size(value, target_unit)

	def as_text(self) -> str:
		"""Format the size in its own unit, e.g. '512 MiB', '2048 sectors' or '50%'."""
		return self.format_size(
			self.unit,
			self.sector_size
		)

	def format_size(
		self,
		target_unit: Unit,
		sector_size: Optional[Size] = None,
		include_unit: bool = True
	) -> str:
		"""
		Format the size converted to `target_unit`, e.g. '512 MiB' (or '512'
		when include_unit is False). Percentages are always shown as '<value>%'.
		"""
		if self.unit == Unit.Percent:
			return f'{self.value}%'

		target_size = self.convert(target_unit, sector_size)
		if include_unit:
			return f'{target_size.value} {target_size.unit.name}'
		return f'{target_size.value}'

	def _normalize(self) -> int:
		"""Return the size in bytes."""
		if self.unit == Unit.Percent:
			return self.convert(Unit.B).value
		elif self.unit == Unit.sectors and self.sector_size is not None:
			return self.value * self.sector_size._normalize()
		return int(self.value * self.unit.value)  # type: ignore

	def __sub__(self, other: Size) -> Size:
		"""Absolute difference of the two sizes, in bytes."""
		src_norm = self._normalize()
		dest_norm = other._normalize()
		return Size(abs(src_norm - dest_norm), Unit.B)

	# comparisons against non-Size objects defer to Python (False for ==,
	# TypeError for ordering) instead of failing with AttributeError

	def __lt__(self, other):
		if not isinstance(other, Size):
			return NotImplemented
		return self._normalize() < other._normalize()

	def __le__(self, other):
		if not isinstance(other, Size):
			return NotImplemented
		return self._normalize() <= other._normalize()

	def __eq__(self, other):
		if not isinstance(other, Size):
			return NotImplemented
		return self._normalize() == other._normalize()

	def __ne__(self, other):
		if not isinstance(other, Size):
			return NotImplemented
		return self._normalize() != other._normalize()

	def __gt__(self, other):
		if not isinstance(other, Size):
			return NotImplemented
		return self._normalize() > other._normalize()

	def __ge__(self, other):
		if not isinstance(other, Size):
			return NotImplemented
		return self._normalize() >= other._normalize()

class _BtrfsSubvolumeInfo:
	name: Path
	mountpoint: Optional[Path]

@dataclass
class _PartitionInfo:
	"""Information about an existing partition, read through pyparted."""
	partition: Partition
	name: str
	type: PartitionType
	fs_type: Optional[FilesystemType]
	path: Path
	start: Size
	length: Size
	flags: List[PartitionFlag]
	partuuid: str
	disk: Disk
	mountpoints: List[Path]
	btrfs_subvol_infos: List[_BtrfsSubvolumeInfo] = field(default_factory=list)

	def table_data(self) -> Dict[str, Any]:
		"""Return display-ready values for a table row (start/length in MiB)."""
		part_info = {
			'Type': self.type.value,
			'Filesystem': self.fs_type.value if self.fs_type else str(_('Unknown')),
			'Path': str(self.path),
			'Start': self.start.format_size(Unit.MiB),
			'Length': self.length.format_size(Unit.MiB),
			'Flags': ', '.join([f.name for f in self.flags])
		}

		if self.btrfs_subvol_infos:
			part_info['Btrfs vol.'] = f'{len(self.btrfs_subvol_infos)} subvolumes'

		return part_info

	@classmethod
	def from_partition(
		cls,
		partition: Partition,
		fs_type: Optional[FilesystemType],
		partuuid: str,
		mountpoints: List[Path],
		btrfs_subvol_infos: Optional[List[_BtrfsSubvolumeInfo]] = None
	) -> _PartitionInfo:
		"""
		Build the info from a pyparted Partition plus the details passed in by the caller.

		:param btrfs_subvol_infos: subvolumes of the partition; defaults to a
			fresh empty list (the former `[]` default was shared between calls)
		"""
		partition_type = PartitionType.get_type_from_code(partition.type)
		flags = [f for f in PartitionFlag if partition.getFlag(f.value)]

		# the geometry start is a sector number of the disk's sector size
		start = Size(
			partition.geometry.start,
			Unit.sectors,
			Size(partition.disk.device.sectorSize, Unit.B)
		)

		length = Size(int(partition.getLength(unit='B')), Unit.B)

		return cls(
			partition=partition,
			# NOTE(review): pyparted's `name` property is assumed here — confirm
			# it is available for all disk labels in use
			name=partition.name,
			type=partition_type,
			fs_type=fs_type,
			path=Path(partition.path),
			start=start,
			length=length,
			flags=flags,
			partuuid=partuuid,
			disk=partition.disk,
			mountpoints=mountpoints,
			btrfs_subvol_infos=list(btrfs_subvol_infos) if btrfs_subvol_infos else []
		)

@dataclass
class _DeviceInfo:
	"""Information about a block device, read through pyparted."""
	model: str
	path: Path
	type: str
	total_size: Size
	free_space_regions: List[DeviceGeometry]
	sector_size: Size
	read_only: bool
	dirty: bool

	def table_data(self) -> Dict[str, Any]:
		"""Return display-ready values for a table row; free space is the sum of all free regions in MiB."""
		total_free_space = sum([region.get_length(unit=Unit.MiB) for region in self.free_space_regions])
		return {
			'Model': self.model,
			'Path': str(self.path),
			'Type': self.type,
			'Size': self.total_size.format_size(Unit.MiB),
			'Free space': int(total_free_space),
			'Sector size': self.sector_size.value,
			'Read only': self.read_only
		}

	@classmethod
	def from_disk(cls, disk: Disk) -> _DeviceInfo:
		"""Build the info from a pyparted Disk and its device."""
		device = disk.device
		# parted.devices maps the numeric device type to a readable name
		device_type = parted.devices[device.type]

		sector_size = Size(device.sectorSize, Unit.B)
		free_space = [DeviceGeometry(g, sector_size) for g in disk.getFreeSpaceRegions()]

		return cls(
			model=device.model.strip(),
			path=Path(device.path),
			type=device_type,
			total_size=Size(int(device.getLength(unit='B')), Unit.B),
			free_space_regions=free_space,
			sector_size=sector_size,
			read_only=device.readOnly,
			dirty=device.dirty
		)

@dataclass
class SubvolumeModification:
	"""
	A btrfs subvolume with an optional mountpoint and the `compress` /
	`nodatacow` mount options.
	"""
	name: Path
	mountpoint: Optional[Path] = None
	compress: bool = False
	nodatacow: bool = False

	@classmethod
	def from_existing_subvol_info(cls, info: _BtrfsSubvolumeInfo) -> SubvolumeModification:
		"""Mirror an existing subvolume (its name and mountpoint, default mount options)."""
		return cls(info.name, mountpoint=info.mountpoint)

	@classmethod
	def parse_args(cls, subvol_args: List[Dict[str, Any]]) -> List[SubvolumeModification]:
		"""
		Rebuild subvolume modifications from their __dump__() form.

		Entries without a name are skipped. The mountpoint is optional; a
		missing, empty or 'None' value (the latter was written by older
		versions of __dump__ for unmounted subvolumes) yields None, so
		unmounted subvolumes survive a save/load round trip.
		"""
		mods = []
		for entry in subvol_args:
			if not entry.get('name', None):
				debug(f'Subvolume arg is missing name: {entry}')
				continue

			raw_mountpoint = entry.get('mountpoint', None)
			mountpoint = Path(raw_mountpoint) if raw_mountpoint and raw_mountpoint != 'None' else None

			mods.append(
				cls(
					Path(entry['name']),
					mountpoint,
					entry.get('compress', False),
					entry.get('nodatacow', False)
				)
			)

		return mods

	def mount_options(self) -> List[str]:
		"""Enabled mount options, in the order compress, nodatacow."""
		options = []
		options += ['compress'] if self.compress else []
		options += ['nodatacow'] if self.nodatacow else []
		return options

	def relative_mountpoint(self) -> Path:
		"""
		Will return the relative path based on the anchor
		e.g. Path('/mnt/test') -> Path('mnt/test')

		:raises ValueError: if no mountpoint is set
		"""
		if self.mountpoint is not None:
			return self.mountpoint.relative_to(self.mountpoint.anchor)

		raise ValueError('Mountpoint is not specified')

	def is_root(self, relative_mountpoint: Optional[Path] = None) -> bool:
		"""
		Return True if the subvolume is mounted at the root.

		With relative_mountpoint given, "root" means mounted exactly at that
		path; otherwise it means mounted at '/'.
		"""
		if self.mountpoint:
			if relative_mountpoint is not None:
				# same result as relative_to(...) == Path('.'), but does not
				# raise ValueError for mountpoints outside relative_mountpoint
				return self.mountpoint == Path(relative_mountpoint)
			return self.mountpoint == Path('/')
		return False

	def __dump__(self) -> Dict[str, Any]:
		"""Serialize for the configuration file; parse_args() reads this back."""
		return {
			'name': str(self.name),
			'mountpoint': str(self.mountpoint) if self.mountpoint else None,
			'compress': self.compress,
			'nodatacow': self.nodatacow
		}

	def table_data(self) -> Dict[str, Any]:
		"""Values for displaying the subvolume in a table."""
		return {
			'name': str(self.name),
			'mountpoint': str(self.mountpoint),
			'compress': self.compress,
			'nodatacow': self.nodatacow
		}

class DeviceGeometry:
	"""
	A region of a device (a parted Geometry) together with the device's
	sector size, so sector positions can also be reported in bytes.
	"""
	def __init__(self, geometry: Geometry, sector_size: Size):
		self._geometry = geometry
		self._sector_size = sector_size

	def start(self) -> int:
		"""Start of the region, as a sector number."""
		return self._geometry.start

	def end(self) -> int:
		"""End of the region, as a sector number (as reported by parted)."""
		return self._geometry.end

	def get_length(self, unit: Unit = Unit.sectors) -> int:
		"""
		Length of the region in `unit`, computed by parted's Geometry.getLength().

		The unit is passed to parted by its name (e.g. 'sectors', 'MiB').
		NOTE(review): for byte units parted presumably returns a float — confirm.
		"""
		return self._geometry.getLength(unit.name)

	def table_data(self) -> Dict[str, Any]:
		"""The sector size plus start, end and length, each shown as '<sectors> / <bytes>'."""
		start = Size(self._geometry.start, Unit.sectors, self._sector_size)
		end = Size(self._geometry.end, Unit.sectors, self._sector_size)
		length = Size(self._geometry.getLength(), Unit.sectors, self._sector_size)

		start_str = f'{self._geometry.start} / {start.format_size(Unit.B, include_unit=False)}'
		end_str = f'{self._geometry.end} / {end.format_size(Unit.B, include_unit=False)}'
		length_str = f'{self._geometry.getLength()} / {length.format_size(Unit.B, include_unit=False)}'

		return {
			'Sector size': self._sector_size.value,
			'Start (sector/B)': start_str,
			'End (sector/B)': end_str,
			'Length (sectors/B)': length_str
		}

@dataclass
class BDevice:
	"""A parted Disk together with its device and partition information."""
	disk: Disk
	device_info: _DeviceInfo
	partition_infos: List[_PartitionInfo]

	def __hash__(self):
		# hash by the device path so devices can be used in sets / as dict keys
		return hash(self.disk.device.path)

class PartitionType(Enum):
	"""Partition types, mapped to and from parted's partition codes."""
	Boot = 'boot'
	Primary = 'primary'
	_Unknown = 'unknown'

	@classmethod
	def get_type_from_code(cls, code: int) -> PartitionType:
		"""
		Map a parted partition type code to a PartitionType.

		Only parted.PARTITION_NORMAL is recognized (as Primary); any other
		code is logged and mapped to _Unknown.
		"""
		if code == parted.PARTITION_NORMAL:
			return PartitionType.Primary

		info(f'Partition code not supported: {code}')
		return PartitionType._Unknown

	def get_partition_code(self) -> Optional[int]:
		"""Return the parted code for this type, or None if it has none."""
		if self == PartitionType.Primary:
			return parted.PARTITION_NORMAL
		elif self == PartitionType.Boot:
			# NOTE(review): PARTITION_BOOT looks like a partition *flag* constant
			# (cf. PartitionFlag), not a partition type code — confirm this is intended.
			return parted.PARTITION_BOOT
		return None

class PartitionFlag(Enum):
	"""
	Partition flags; the values are libparted flag constants.

	Flags are taken from _ped because pyparted uses this module to look up
	its flag definitions, which is the way libparted checks for its flags:
	https://git.savannah.gnu.org/gitweb/?p=parted.git;a=blob;f=libparted/labels/gpt.c;hb=4a0e468ed63fff85a1f9b923189f20945b32f4f1#l183
	"""
	Boot = _ped.PARTITION_BOOT
	XBOOTLDR = _ped.PARTITION_BLS_BOOT  # Note: parted calls this bls_boot

# class PartitionGUIDs(Enum):
# 	"""
# 	A list of partition type GUIDs (as shown by `lsblk -o+PARTTYPE`) can be found here: https://en.wikipedia.org/wiki/GUID_Partition_Table#Partition_type_GUIDs
# 	"""
# 	XBOOTLDR = 'bc13c2ff-59e6-4262-a352-b275fd6f7172'

class FilesystemType(Enum):
	Btrfs = 'btrfs'
	Ext2 = 'ext2'
	Ext3 = 'ext3'
	Ext4 = 'ext4'
	F2fs = 'f2fs'
	Fat16 = 'fat16'
	Fat32 = 'fat32'
	Ntfs = 'ntfs'
	Reiserfs = 'reiserfs'
	Xfs = 'xfs'

	# this is not a FS known to parted, so be careful
	# with the usage from this enum
	Crypto_luks = 'crypto_LUKS'

	def is_crypto(self) -> bool:
		return self == FilesystemType.Crypto_luks

	def fs_type_mount(self) -> str:
		match self:
			case FilesystemType.Ntfs: return 'ntfs3'
			case FilesystemType.Fat32: return 'vfat'
			case _: return self.value  # type: ignore

	def installation_pkg(self) -> Optional[str]:
		match self:
			case FilesystemType.Btrfs: return 'btrfs-progs'
			case FilesystemType.Xfs: return 'xfsprogs'
			case FilesystemType.F2fs: return 'f2fs-tools'
			case _: return None

	def installation_module(self) -> Optional[str]:
		match self:
			case FilesystemType.Btrfs: return 'btrfs'
			case _: return None

	def installation_binary(self) -> Optional[str]:
		match self:
			case FilesystemType.Btrfs: return '/usr/bin/btrfs'
			case _: return None

	def installation_hooks(self) -> Optional[str]:
		match self:
			case FilesystemType.Btrfs: return 'btrfs'
			case _: return None

class ModificationStatus(Enum):
	"""What is to be done with a partition."""
	Exist = 'existing'  # partition already exists
	Modify = 'modify'  # existing partition to be modified
	Delete = 'delete'  # partition to be deleted
	Create = 'create'  # partition to be created

@dataclass
class PartitionModification:
	"""
	A partition to keep, create, modify or delete, with its file system,
	mountpoint, mount options, flags and btrfs subvolumes.

	Every instance carries a private `_obj_id` (a random UUID unless set
	explicitly) that is used for hashing and for referring to the partition
	from the configuration file (see `obj_id`).
	"""
	status: ModificationStatus
	type: PartitionType
	start: Size
	length: Size
	fs_type: Optional[FilesystemType]
	mountpoint: Optional[Path] = None
	mount_options: List[str] = field(default_factory=list)
	flags: List[PartitionFlag] = field(default_factory=list)
	btrfs_subvols: List[SubvolumeModification] = field(default_factory=list)

	# only set if the device was created or exists
	dev_path: Optional[Path] = None
	partuuid: Optional[str] = None
	uuid: Optional[str] = None

	# no annotation: a plain class attribute, not a dataclass field
	_boot_indicator_flags = [PartitionFlag.Boot, PartitionFlag.XBOOTLDR]

	def __post_init__(self):
		# needed to use the object as a dictionary key due to hash func
		if not hasattr(self, '_obj_id'):
			self._obj_id = uuid.uuid4()

		if self.is_exists_or_modify() and not self.dev_path:
			raise ValueError('If partition marked as existing a path must be set')

		if self.fs_type is None and self.status == ModificationStatus.Modify:
			raise ValueError('FS type must not be empty on modifications with status type modify')

	def __hash__(self):
		# NOTE(review): the dataclass-generated __eq__ compares field values while
		# the hash uses _obj_id, so two equal instances can hash differently.
		return hash(self._obj_id)

	@property
	def obj_id(self) -> str:
		"""The internal object id as a string, or '' if it was never set."""
		if hasattr(self, '_obj_id'):
			return str(self._obj_id)
		return ''

	def safe_dev_path(self) -> Path:
		"""Return dev_path; raises ValueError if it is not set."""
		if self.dev_path is None:
			raise ValueError('Device path was not set')
		return self.dev_path

	def safe_fs_type(self) -> FilesystemType:
		"""Return fs_type; raises ValueError if it is not set."""
		if self.fs_type is None:
			raise ValueError('File system type is not set')
		return self.fs_type

	@classmethod
	def from_existing_partition(cls, partition_info: _PartitionInfo) -> PartitionModification:
		"""
		Create a modification with status Exist that mirrors an existing partition.

		If the partition has btrfs subvolumes, the mountpoints come from the
		subvolumes and the partition itself gets no mountpoint; otherwise the
		partition's first mountpoint (if any) is used.
		"""
		if partition_info.btrfs_subvol_infos:
			mountpoint = None
			subvol_mods = [
				SubvolumeModification.from_existing_subvol_info(subvol_info)
				for subvol_info in partition_info.btrfs_subvol_infos
			]
		else:
			mountpoint = partition_info.mountpoints[0] if partition_info.mountpoints else None
			subvol_mods = []

		return cls(
			status=ModificationStatus.Exist,
			type=partition_info.type,
			start=partition_info.start,
			length=partition_info.length,
			fs_type=partition_info.fs_type,
			mountpoint=mountpoint,
			# copy, so set_flag()/invert_flag() do not alter the partition info
			flags=list(partition_info.flags),
			btrfs_subvols=subvol_mods,
			dev_path=partition_info.path,
			partuuid=partition_info.partuuid,
		)

	def relative_mountpoint(self) -> Path:
		"""
		Will return the relative path based on the anchor
		e.g. Path('/mnt/test') -> Path('mnt/test')

		:raises ValueError: if no mountpoint is set
		"""
		if self.mountpoint:
			return self.mountpoint.relative_to(self.mountpoint.anchor)

		raise ValueError('Mountpoint is not specified')

	def is_boot(self) -> bool:
		"""Returns True if any of the boot indicator flags are found in self.flags"""
		return any(set(self.flags) & set(self._boot_indicator_flags))

	def is_root(self, relative_mountpoint: Optional[Path] = None) -> bool:
		"""
		Return True if this partition, or one of its btrfs subvolumes, is the root.

		With relative_mountpoint given, "root" means mounted exactly at that
		path; otherwise it means mounted at '/'. Subvolumes are only checked
		when the partition itself has no mountpoint.
		"""
		if self.mountpoint is not None:
			if relative_mountpoint is not None:
				# same result as relative_to(...) == Path('.'), but does not
				# raise ValueError for mountpoints outside relative_mountpoint
				return self.mountpoint == Path(relative_mountpoint)
			return Path('/') == self.mountpoint

		return any(subvol.is_root(relative_mountpoint) for subvol in self.btrfs_subvols)

	def is_modify(self) -> bool:
		return self.status == ModificationStatus.Modify

	def exists(self) -> bool:
		return self.status == ModificationStatus.Exist

	def is_exists_or_modify(self) -> bool:
		return self.status in [ModificationStatus.Exist, ModificationStatus.Modify]

	def mapper_name(self) -> Optional[str]:
		"""Device-mapper name: the ENC_IDENTIFIER storage value (default 'ai') followed by the device name."""
		if self.dev_path:
			return f'{storage.get("ENC_IDENTIFIER", "ai")}{self.dev_path.name}'
		return None

	def set_flag(self, flag: PartitionFlag):
		"""Add the flag unless it is already set."""
		if flag not in self.flags:
			self.flags.append(flag)

	def invert_flag(self, flag: PartitionFlag):
		"""Remove the flag if it is set, otherwise set it."""
		if flag in self.flags:
			self.flags = [f for f in self.flags if f != flag]
		else:
			self.set_flag(flag)

	def json(self) -> Dict[str, Any]:
		"""
		Called for configuration settings
		"""
		return {
			'obj_id': self.obj_id,
			'status': self.status.value,
			'type': self.type.value,
			'start': self.start.__dump__(),
			'length': self.length.__dump__(),
			'fs_type': self.fs_type.value if self.fs_type else '',
			'mountpoint': str(self.mountpoint) if self.mountpoint else None,
			'mount_options': self.mount_options,
			'flags': [f.name for f in self.flags],
			'dev_path': str(self.dev_path) if self.dev_path else None,
			'btrfs': [vol.__dump__() for vol in self.btrfs_subvols]
		}

	def table_data(self) -> Dict[str, Any]:
		"""
		Called for displaying data in table format
		"""
		part_mod = {
			'Status': self.status.value,
			'Device': str(self.dev_path) if self.dev_path else '',
			'Type': self.type.value,
			'Start': self.start.format_size(Unit.MiB),
			'Length': self.length.format_size(Unit.MiB),
			'FS type': self.fs_type.value if self.fs_type else 'Unknown',
			'Mountpoint': self.mountpoint if self.mountpoint else '',
			'Mount options': ', '.join(self.mount_options),
			'Flags': ', '.join([f.name for f in self.flags]),
		}

		if self.btrfs_subvols:
			part_mod['Btrfs vol.'] = f'{len(self.btrfs_subvols)} subvolumes'

		return part_mod

@dataclass
class DeviceModification:
	"""Planned changes for one device: whether to wipe it, and its partitions."""
	device: BDevice
	wipe: bool
	partitions: List[PartitionModification] = field(default_factory=list)

	def device_path(self) -> Path:
		"""Path of the device being modified."""
		return self.device.device_info.path

	def add_partition(self, partition: PartitionModification):
		"""Append a partition modification to this device."""
		self.partitions.append(partition)

	def get_efi_partition(self) -> Optional[PartitionModification]:
		"""
		Similar to get_boot_partition() but excludes XBOOTLDR partitions from its candidates:
		returns the first boot-flagged FAT32 partition without the XBOOTLDR flag.
		"""
		filtered = filter(lambda x: x.is_boot() and x.fs_type == FilesystemType.Fat32 and PartitionFlag.XBOOTLDR not in x.flags, self.partitions)
		return next(filtered, None)

	def get_boot_partition(self) -> Optional[PartitionModification]:
		"""
		Returns the first partition marked as XBOOTLDR (PARTTYPE id of bc13c2ff-...) or Boot and has a mountpoint.
		Only returns XBOOTLDR if separate EFI is detected using self.get_efi_partition().

		When an EFI partition exists, another boot partition with a mountpoint
		is preferred, falling back to the EFI partition. Without an EFI
		partition, the first boot partition with a mountpoint is returned.
		"""
		if efi_partition := self.get_efi_partition():
			filtered = filter(lambda x: x.is_boot() and x != efi_partition and x.mountpoint, self.partitions)
			if boot_partition := next(filtered, None):
				return boot_partition
			return efi_partition

		# previously unreachable: without an EFI partition None was always returned
		filtered = filter(lambda x: x.is_boot() and x.mountpoint, self.partitions)
		return next(filtered, None)

	def get_root_partition(self, relative_path: Optional[Path]) -> Optional[PartitionModification]:
		"""First partition whose is_root(relative_path) is true, or None."""
		filtered = filter(lambda x: x.is_root(relative_path), self.partitions)
		return next(filtered, None)

	def __dump__(self) -> Dict[str, Any]:
		"""
		Called when generating configuration files
		"""
		return {
			'device': str(self.device.device_info.path),
			'wipe': self.wipe,
			'partitions': [p.json() for p in self.partitions]
		}

class EncryptionType(Enum):
	"""Disk encryption methods."""
	NoEncryption = "no_encryption"
	Luks = "luks"

	@classmethod
	def _encryption_type_mapper(cls) -> Dict[str, 'EncryptionType']:
		"""Display text -> encryption type (only Luks has a display text here)."""
		return {
			'Luks': EncryptionType.Luks
		}

	@classmethod
	def text_to_type(cls, text: str) -> 'EncryptionType':
		"""Return the type for a display text; raises KeyError for unknown text."""
		mapping = cls._encryption_type_mapper()
		return mapping[text]

	@classmethod
	def type_to_text(cls, type_: 'EncryptionType') -> str:
		"""Return the display text of a type; raises KeyError if it has none."""
		mapping = cls._encryption_type_mapper()
		text_by_type = {enc_type: text for text, enc_type in mapping.items()}
		return text_by_type[type_]

@dataclass
class DiskEncryption:
	"""Encryption settings: the method, password, partitions to encrypt and an optional FIDO2 device."""
	encryption_type: EncryptionType = EncryptionType.Luks
	encryption_password: str = ''
	partitions: List[PartitionModification] = field(default_factory=list)
	hsm_device: Optional[Fido2Device] = None

	def should_generate_encryption_file(self, part_mod: PartitionModification) -> bool:
		"""True for partitions selected for encryption that are not mounted at '/'."""
		return part_mod in self.partitions and part_mod.mountpoint != Path('/')

	def json(self) -> Dict[str, Any]:
		"""
		Serialize for the configuration file.

		Partitions are referenced by their obj_id. The password is not
		included; parse_arg() takes it as a separate argument.
		"""
		obj: Dict[str, Any] = {
			'encryption_type': self.encryption_type.value,
			'partitions': [p.obj_id for p in self.partitions]
		}

		if self.hsm_device:
			obj['hsm_device'] = self.hsm_device.json()

		return obj

	@classmethod
	def parse_arg(
		cls,
		disk_config: DiskLayoutConfiguration,
		arg: Dict[str, Any],
		password: str = ''
	) -> 'DiskEncryption':
		"""
		Rebuild encryption settings from the dict produced by json().

		Partitions are resolved by matching their obj_id against the partitions
		of `disk_config`; ids that match no partition are ignored.
		"""
		wanted_ids = set(arg.get('partitions', []))
		enc_partitions = [
			part
			for mod in disk_config.device_modifications
			for part in mod.partitions
			if part.obj_id in wanted_ids
		]

		enc = cls(
			EncryptionType(arg['encryption_type']),
			password,
			enc_partitions
		)

		if hsm := arg.get('hsm_device', None):
			enc.hsm_device = Fido2Device.parse_arg(hsm)

		return enc

@dataclass
class Fido2Device:
	"""A FIDO2 device, identified by its device path, manufacturer and product name."""
	path: Path
	manufacturer: str
	product: str

	def json(self) -> Dict[str, str]:
		"""Serialize to a dict that parse_arg() can read back."""
		return {
			'path': str(self.path),
			'manufacturer': self.manufacturer,
			'product': self.product
		}

	@classmethod
	def parse_arg(cls, arg: Dict[str, str]) -> 'Fido2Device':
		"""Rebuild a device from the dict produced by json()."""
		return cls(
			Path(arg['path']),
			arg['manufacturer'],
			arg['product']
		)

@dataclass
class LsblkInfo:
	"""
	One block device entry of `lsblk --json` output.

	Attribute names mirror lsblk column names (see _clean_field(), e.g.
	'fsuse_percentage' <-> 'fsuse%'); child devices are nested in `children`.
	"""
	name: str = ''
	path: Path = Path()
	pkname: str = ''
	size: Size = field(default_factory=lambda: Size(0, Unit.B))
	log_sec: int = 0
	pttype: str = ''
	ptuuid: str = ''
	rota: bool = False
	tran: Optional[str] = None
	partuuid: Optional[str] = None
	parttype: Optional[str] = None
	uuid: Optional[str] = None
	fstype: Optional[str] = None
	fsver: Optional[str] = None
	fsavail: Optional[str] = None
	fsuse_percentage: Optional[str] = None
	type: Optional[str] = None
	mountpoint: Optional[Path] = None
	mountpoints: List[Path] = field(default_factory=list)
	fsroots: List[Path] = field(default_factory=list)
	children: List[LsblkInfo] = field(default_factory=list)

	def json(self) -> Dict[str, Any]:
		"""JSON-serializable view of the entry (sizes in MiB, paths as strings)."""
		return {
			'path': str(self.path),
			'pkname': self.pkname,
			'size': self.size.format_size(Unit.MiB),
			'log_sec': self.log_sec,
			'pttype': self.pttype,
			'ptuuid': self.ptuuid,
			'rota': self.rota,
			'tran': self.tran,
			'partuuid': self.partuuid,
			'parttype': self.parttype,
			'uuid': self.uuid,
			'fstype': self.fstype,
			'fsver': self.fsver,
			'fsavail': self.fsavail,
			'fsuse_percentage': self.fsuse_percentage,
			'type': self.type,
			'mountpoint': str(self.mountpoint) if self.mountpoint else None,
			'mountpoints': [str(m) for m in self.mountpoints],
			'fsroots': [str(r) for r in self.fsroots],
			'children': [c.json() for c in self.children]
		}

	def btrfs_subvol_info(self) -> Dict[Path, Path]:
		"""
		It is assumed that lsblk will contain the fields as

		"mountpoints": ["/mnt/archinstall/log", "/mnt/archinstall/home", "/mnt/archinstall", ...]
		"fsroots": ["/@log", "/@home", "/@"...]

		we'll thereby map the fsroot, which are the mounted filesystem roots
		to the corresponding mountpoints
		"""
		return dict(zip(self.fsroots, self.mountpoints))

	@classmethod
	def exclude(cls) -> List[str]:
		"""Fields that are not lsblk columns."""
		return ['children']

	@classmethod
	def fields(cls) -> List[str]:
		"""Names of the fields that correspond to lsblk columns, in declaration order."""
		return [f.name for f in dataclasses.fields(LsblkInfo) if f.name not in cls.exclude()]

	@classmethod
	def from_json(cls, blockdevice: Dict[str, Any]) -> LsblkInfo:
		"""
		Build an entry from one element of lsblk's "blockdevices" list.

		Values are converted according to the type of the field's default:
		Path fields become Path, Size fields become byte Sizes (lsblk is run
		with -b), everything else is stored as returned. Children are parsed
		recursively, null mountpoints are dropped and fsroots are stored
		without their leading '/'.
		"""
		lsblk_info = cls()

		for f in cls.fields():
			lsblk_field = _clean_field(f, CleanType.Blockdevice)
			data_field = _clean_field(f, CleanType.Dataclass)

			val: Any = None
			if isinstance(getattr(lsblk_info, data_field), Path):
				val = Path(blockdevice[lsblk_field])
			elif isinstance(getattr(lsblk_info, data_field), Size):
				val = Size(blockdevice[lsblk_field], Unit.B)
			else:
				# without this `else` the conversions above were overwritten
				val = blockdevice[lsblk_field]

			setattr(lsblk_info, data_field, val)

		lsblk_info.children = [cls.from_json(child) for child in blockdevice.get('children', [])]

		# sometimes lsblk returns 'mountpoints': [null]
		lsblk_info.mountpoints = [Path(mnt) for mnt in (lsblk_info.mountpoints or []) if mnt]

		# store the fsroot entries without the leading /
		fs_roots = []
		for r in (lsblk_info.fsroots or []):
			if r:
				path = Path(r)
				fs_roots.append(path.relative_to(path.anchor))
		lsblk_info.fsroots = fs_roots

		return lsblk_info

class CleanType(Enum):
	"""Naming conventions that _clean_field() translates a field name into."""
	Blockdevice = auto()  # key in lsblk's JSON output
	Dataclass = auto()  # LsblkInfo attribute name
	Lsblk = auto()  # column name passed to `lsblk -o`

def _clean_field(name: str, clean_type: CleanType) -> str:
	"""
	Translate a field name between LsblkInfo attribute names and lsblk names.

	Blockdevice / Lsblk: '_percentage' -> '%', then '_' -> '-'.
	Dataclass: lower-cased, '-' -> '_', then '%' -> '_percentage'.
	"""
	if clean_type in (CleanType.Blockdevice, CleanType.Lsblk):
		return name.replace('_percentage', '%').replace('_', '-')
	if clean_type == CleanType.Dataclass:
		return name.lower().replace('-', '_').replace('%', '_percentage')

def _fetch_lsblk_info(dev_path: Optional[Union[Path, str]] = None, retry: int = 3) -> List[LsblkInfo]:
	"""
	Run lsblk (JSON output, sizes in bytes) and parse the reported block devices.

	:param dev_path: device to query; all block devices when empty or None
	:param retry: number of attempts (values below 1 are treated as 1). A failed
		call that produced output is logged and retried after one second.
	:raises SysCallError: if lsblk fails without output, or fails on the last attempt
	:raises json.decoder.JSONDecodeError: if lsblk's output is not valid JSON
	:raises DiskError: if lsblk produced no output
	"""
	fields = [_clean_field(f, CleanType.Lsblk) for f in LsblkInfo.fields()]
	lsblk_fields = ','.join(fields)

	if not dev_path:
		dev_path = ''

	# a negative value used to skip the loop and then crash on an unbound `result`
	retry = max(retry, 1)

	for retry_attempt in range(retry):
		try:
			result = SysCommand(f'lsblk --json -b -o+{lsblk_fields} {dev_path}')
			break
		except SysCallError as err:
			# Get the output minus the message/info from lsblk if it returns a non-zero exit code.
			if err.worker:
				err_str = err.worker.decode('UTF-8')
				debug(f'Error calling lsblk: {err_str}')
			else:
				raise err

			if retry_attempt == retry - 1:
				raise err

			time.sleep(1)

	try:
		if decoded := result.decode('utf-8'):
			block_devices = json.loads(decoded)
			blockdevices = block_devices['blockdevices']
			return [LsblkInfo.from_json(device) for device in blockdevices]
	except json.decoder.JSONDecodeError as err:
		error(f"Could not decode lsblk JSON: {result}")
		raise err

	raise DiskError(f'Failed to read disk "{dev_path}" with lsblk')

def get_lsblk_info(dev_path: Union[Path, str]) -> LsblkInfo:
	"""Return the lsblk information for one device (the first entry lsblk reports for it)."""
	infos = _fetch_lsblk_info(dev_path)
	if not infos:
		raise DiskError(f'lsblk failed to retrieve information for "{dev_path}"')
	return infos[0]

def get_all_lsblk_info() -> List[LsblkInfo]:
	"""Return the lsblk information of every block device."""
	all_infos = _fetch_lsblk_info()
	return all_infos

def get_lsblk_by_mountpoint(mountpoint: Path, as_prefix: bool = False) -> List[LsblkInfo]:
	"""
	Return all block devices (searched recursively through `children`) mounted at `mountpoint`.

	:param as_prefix: also match devices mounted anywhere below `mountpoint`.
		Matching is per path component, so '/mnt/a' matches '/mnt/a/boot' but
		not '/mnt/ab' (a plain string-prefix check matched both).
	"""
	mountpoint = Path(mountpoint)

	def _matches(entry: LsblkInfo) -> bool:
		if as_prefix:
			return any(m.is_relative_to(mountpoint) for m in entry.mountpoints)
		return mountpoint in entry.mountpoints

	def _check(infos: List[LsblkInfo]) -> List[LsblkInfo]:
		devices = []
		for entry in infos:
			if _matches(entry):
				devices.append(entry)
			# lsblk nests dependent devices under `children`
			devices += _check(entry.children)
		return devices

	return _check(get_all_lsblk_info())