Merge "Add typing (2/3)"

This commit is contained in:
Zuul
2025-09-15 14:13:54 +00:00
committed by Gerrit Code Review
4 changed files with 159 additions and 118 deletions

View File

@@ -21,7 +21,7 @@ from oslo_utils.imageutils import format_inspector
from oslo_utils.version import version_info
def main():
def main() -> None:
"""Run image security checks and give feedback.
Runs the image format detector and related security checks against
@@ -83,6 +83,10 @@ def main():
sys.exit(1)
inspector = format_inspector.detect_file_format(image)
if inspector is None:
print('Could not find format inspector for image', file=sys.stderr)
sys.exit(1)
safe = True
try:
inspector.safety_check()

View File

@@ -22,8 +22,10 @@ complex-format images.
"""
import abc
from collections.abc import Callable, Generator
import struct
from typing import cast, TypedDict
from typing import cast, Any, IO, TypedDict
from collections.abc import Iterator
import logging
from oslo_utils._i18n import _
@@ -32,7 +34,9 @@ from oslo_utils import units
LOG = logging.getLogger(__name__)
def _chunked_reader(fileobj, chunk_size=512):
def _chunked_reader(
fileobj: IO[bytes] | 'InspectWrapper', chunk_size: int = 512
) -> Generator[bytes, None, None]:
while True:
chunk = fileobj.read(chunk_size)
if not chunk:
@@ -57,21 +61,23 @@ class CaptureRegion:
variable data structures.
"""
def __init__(self, offset, length, min_length=None):
def __init__(
self, offset: int, length: int, min_length: int | None = None
) -> None:
self.offset = offset
self.length = length
self.data = b''
self.data: bytes | bytearray = b''
self.min_length = min_length
@property
def complete(self):
def complete(self) -> bool:
"""Returns True when we have captured the desired data."""
if self.min_length is not None:
return self.min_length <= len(self.data)
else:
return self.length == len(self.data)
def capture(self, chunk, current_position):
def capture(self, chunk: bytes, current_position: int) -> bytes:
"""Process a chunk of data.
This should be called for each chunk in the read loop, at least
@@ -94,6 +100,7 @@ class CaptureRegion:
lead_gap = 0
self.data += chunk[lead_gap:]
self.data = self.data[: self.length]
return b''
class EndCaptureRegion(CaptureRegion):
@@ -107,22 +114,23 @@ class EndCaptureRegion(CaptureRegion):
will also be the region length)
"""
def __init__(self, offset):
def __init__(self, offset: int) -> None:
super().__init__(offset, offset)
# We don't want to indicate completeness until we have the data we
# want *and* have reached EOF
self._complete = False
def capture(self, chunk, current_position):
def capture(self, chunk: bytes, current_position: int) -> bytes:
self.data += chunk
self.data = self.data[0 - self.length :]
self.offset = current_position - len(self.data)
return b''
@property
def complete(self):
def complete(self) -> bool:
return super().complete and self._complete
def finish(self):
def finish(self) -> None:
"""Indicate that the entire stream has been read."""
self._complete = True
@@ -130,7 +138,12 @@ class EndCaptureRegion(CaptureRegion):
class SafetyCheck:
"""Represents a named safety check on an inspector"""
def __init__(self, name, target_fn, description=None):
def __init__(
self,
name: str,
target_fn: Callable[[], None],
description: str | None = None,
) -> None:
"""A safety check, its meta info, and result.
:param name: Should be a short name of the check (ideally no spaces)
@@ -143,7 +156,7 @@ class SafetyCheck:
self.target_fn = target_fn
self.description = description
def __call__(self):
def __call__(self) -> None:
"""Executes the target check function, records the result.
:raises SafetyViolation: If an error check fails
@@ -162,7 +175,7 @@ class SafetyCheck:
raise SafetyViolation(_('Unexpected error'))
@classmethod
def null(cls):
def null(cls) -> 'SafetyCheck':
"""The "null" safety check always returns True.
This should only be used if there are no meaningful checks that can
@@ -175,7 +188,7 @@ class SafetyCheck:
)
@classmethod
def banned(cls):
def banned(cls) -> 'SafetyCheck':
"""The "banned" safety check always returns False.
This should be used for formats we want to identify but never allow,
@@ -183,7 +196,7 @@ class SafetyCheck:
we are unable to check for safety.
"""
def fail():
def fail() -> None:
raise SafetyViolation(_('This file format is not allowed'))
return cls('banned', fail, _('This file format is not allowed'))
@@ -204,7 +217,7 @@ class SafetyViolation(Exception):
class SafetyCheckFailed(Exception):
"""Indicates that one or more of a series of safety checks failed."""
def __init__(self, failures):
def __init__(self, failures: dict[str, Any]) -> None:
super().__init__(
_('Safety checks failed: %s') % ','.join(failures.keys())
)
@@ -224,7 +237,7 @@ class FileInspector(abc.ABC):
# This should match what qemu-img thinks this format is
NAME = ''
def __init__(self, tracing=False):
def __init__(self, tracing: bool = False) -> None:
self._total_count = 0
# NOTE(danms): The logging in here is extremely verbose for a reason,
@@ -243,18 +256,18 @@ class FileInspector(abc.ABC):
'All inspectors must define at least one safety check'
)
def _trace(self, *args, **kwargs):
def _trace(self, *args: Any, **kwargs: Any) -> None:
if self._tracing:
LOG.debug(*args, **kwargs)
@abc.abstractmethod
def _initialize(self):
def _initialize(self) -> None:
"""Set up inspector before we start processing data.
This should add the initial set of capture regions and safety checks.
"""
def finish(self):
def finish(self) -> None:
"""Indicate that the entire stream has been read.
This should be called when the entire stream has been completely read,
@@ -265,7 +278,7 @@ class FileInspector(abc.ABC):
if isinstance(region, EndCaptureRegion):
region.finish()
def _capture(self, chunk, only=None):
def _capture(self, chunk: bytes, only: list[str] | None = None) -> None:
if self._finished:
raise RuntimeError(
'Inspector has been marked finished, '
@@ -277,7 +290,7 @@ class FileInspector(abc.ABC):
if isinstance(region, EndCaptureRegion) or not region.complete:
region.capture(chunk, self._total_count)
def eat_chunk(self, chunk):
def eat_chunk(self, chunk: bytes) -> None:
"""Call this to present chunks of the file to the inspector."""
pre_regions = set(self._capture_regions.values())
pre_complete = {
@@ -313,7 +326,7 @@ class FileInspector(abc.ABC):
for region in post_complete - pre_complete:
self.region_complete(self.region_name(region))
def post_process(self):
def post_process(self) -> None:
"""Post-read hook to process what has been read so far.
This will be called after each chunk is read and potentially captured
@@ -323,36 +336,36 @@ class FileInspector(abc.ABC):
"""
pass
def region(self, name):
def region(self, name: str) -> CaptureRegion:
"""Get a CaptureRegion by name."""
return self._capture_regions[name]
def region_name(self, region):
def region_name(self, region: CaptureRegion) -> str:
"""Return the region name for a region object."""
for name in self._capture_regions:
if self._capture_regions[name] is region:
return name
raise ValueError('No such region')
def new_region(self, name, region):
def new_region(self, name: str, region: CaptureRegion) -> None:
"""Add a new CaptureRegion by name."""
if self.has_region(name):
# This is a bug, we tried to add the same region twice
raise ImageFormatError(f'Inspector re-added region {name}')
self._capture_regions[name] = region
def has_region(self, name):
def has_region(self, name: str) -> bool:
"""Returns True if named region has been defined."""
return name in self._capture_regions
def delete_region(self, name):
def delete_region(self, name: str) -> None:
"""Remove a capture region by name.
This will raise KeyError if the region does not exist.
"""
del self._capture_regions[name]
def region_complete(self, region_name):
def region_complete(self, region_name: str) -> None:
"""Called when a region becomes complete.
Subclasses may implement this if they need to do one-time processing
@@ -360,7 +373,7 @@ class FileInspector(abc.ABC):
"""
pass
def add_safety_check(self, check):
def add_safety_check(self, check: SafetyCheck) -> None:
if not isinstance(check, SafetyCheck):
raise RuntimeError(
_('Unable to add safety check of type %s')
@@ -372,16 +385,16 @@ class FileInspector(abc.ABC):
@property
@abc.abstractmethod
def format_match(self):
def format_match(self) -> bool:
"""Returns True if the file appears to be the expected format."""
@property
def virtual_size(self):
def virtual_size(self) -> int:
"""Returns the virtual size of the disk image, or zero if unknown."""
return self._total_count
@property
def actual_size(self):
def actual_size(self) -> int:
"""Returns the total size of the file, usually smaller than
virtual_size. NOTE: this will only be accurate if the entire
file is read and processed.
@@ -389,16 +402,16 @@ class FileInspector(abc.ABC):
return self._total_count
@property
def complete(self):
def complete(self) -> bool:
"""Returns True if we have all the information needed."""
return all(r.complete for r in self._capture_regions.values())
def __str__(self):
def __str__(self) -> str:
"""The string name of this file format."""
return self.NAME
@property
def context_info(self):
def context_info(self) -> dict[str, int]:
"""Return info on amount of data held in memory for auditing.
This is a dict of region:sizeinbytes items that the inspector
@@ -410,7 +423,7 @@ class FileInspector(abc.ABC):
}
@classmethod
def from_file(cls, filename):
def from_file(cls, filename: str) -> 'FileInspector':
"""Read as much of a file as necessary to complete inspection.
NOTE: Because we only read as much of the file as necessary, the
@@ -431,7 +444,7 @@ class FileInspector(abc.ABC):
raise ImageFormatError('File is not in requested format')
return inspector
def safety_check(self):
def safety_check(self) -> None:
"""Perform all checks to determine if this file is safe.
:raises ImageFormatError: If safety cannot be guaranteed because of
@@ -470,12 +483,12 @@ class FileInspector(abc.ABC):
class RawFileInspector(FileInspector):
NAME = 'raw'
def _initialize(self):
def _initialize(self) -> None:
"""Raw files have nothing to capture and no safety checks."""
self.add_safety_check(SafetyCheck.null())
@property
def format_match(self):
def format_match(self) -> bool:
# By definition, raw files are unformatted and thus we always match
return True
@@ -520,7 +533,7 @@ class QcowInspector(FileInspector):
I_FEATURES_DATAFILE_BIT = 3
I_FEATURES_MAX_BIT = 4
def _initialize(self):
def _initialize(self) -> None:
self.qemu_header_info: QEMUHeader = {}
self.new_region('header', CaptureRegion(0, 512))
self.add_safety_check(
@@ -531,7 +544,7 @@ class QcowInspector(FileInspector):
SafetyCheck('unknown_features', self.check_unknown_features)
)
def region_complete(self, region):
def region_complete(self, region_name: str) -> None:
self.qemu_header_info = cast(
QEMUHeader,
dict(
@@ -552,16 +565,16 @@ class QcowInspector(FileInspector):
self.qemu_header_info = {}
@property
def virtual_size(self):
def virtual_size(self) -> int:
return self.qemu_header_info.get('size', 0)
@property
def format_match(self):
def format_match(self) -> bool:
if not self.region('header').complete:
return False
return self.qemu_header_info.get('magic') == b'QFI\xfb'
def check_backing_file(self):
def check_backing_file(self) -> None:
bf_offset_bytes = self.region('header').data[
self.BF_OFFSET : self.BF_OFFSET + self.BF_OFFSET_LEN
]
@@ -570,7 +583,7 @@ class QcowInspector(FileInspector):
if bf_offset != 0:
raise SafetyViolation('Image has a backing file')
def check_unknown_features(self):
def check_unknown_features(self) -> None:
ver = self.qemu_header_info.get('version')
if ver == 2:
# Version 2 did not have the feature flag array, so no need to
@@ -611,7 +624,7 @@ class QcowInspector(FileInspector):
)
raise SafetyViolation('Unknown QCOW2 features found')
def check_data_file(self):
def check_data_file(self) -> None:
i_features = self.region('header').data[
self.I_FEATURES : self.I_FEATURES + self.I_FEATURES_LEN
]
@@ -627,14 +640,14 @@ class QcowInspector(FileInspector):
class QEDInspector(FileInspector):
NAME = 'qed'
def _initialize(self):
def _initialize(self) -> None:
self.new_region('header', CaptureRegion(0, 512))
# QED format is not supported by anyone, but we want to detect it
# and mark it as just always unsafe.
self.add_safety_check(SafetyCheck.banned())
@property
def format_match(self):
def format_match(self) -> bool:
if not self.region('header').complete:
return False
return self.region('header').data.startswith(b'QED\x00')
@@ -658,23 +671,25 @@ class VHDInspector(FileInspector):
NAME = 'vhd'
def _initialize(self):
def _initialize(self) -> None:
self.new_region('header', CaptureRegion(0, 512))
self.add_safety_check(SafetyCheck.null())
@property
def format_match(self):
def format_match(self) -> bool:
return self.region('header').data.startswith(b'conectix')
@property
def virtual_size(self):
def virtual_size(self) -> int:
if not self.region('header').complete:
return 0
if not self.format_match:
return 0
return struct.unpack('>Q', self.region('header').data[40:48])[0]
return cast(
int, struct.unpack('>Q', self.region('header').data[40:48])[0]
)
# The VHDX format consists of a complex dynamic little-endian
@@ -749,12 +764,12 @@ class VHDXInspector(FileInspector):
VIRTUAL_DISK_SIZE = '2FA54224-CD1B-4876-B211-5DBED83BF4B8'
VHDX_METADATA_TABLE_MAX_SIZE = 32 * 2048 # From qemu
def _initialize(self):
def _initialize(self) -> None:
self.new_region('ident', CaptureRegion(0, 32))
self.new_region('header', CaptureRegion(192 * 1024, 64 * 1024))
self.add_safety_check(SafetyCheck.null())
def post_process(self):
def post_process(self) -> None:
# After reading a chunk, we may have the following conditions:
#
# 1. We may have just completed the header region, and if so,
@@ -775,18 +790,18 @@ class VHDXInspector(FileInspector):
self.new_region('vds', region)
@property
def format_match(self):
def format_match(self) -> bool:
return self.region('ident').data.startswith(b'vhdxfile')
@staticmethod
def _guid(buf):
def _guid(buf: bytes | bytearray) -> str:
"""Format a MSFT GUID from the 16-byte input buffer."""
guid_format = '<IHHBBBBBBBB'
return '{:08X}-{:04X}-{:04X}-{:02X}{:02X}-{:02X}{:02X}{:02X}{:02X}{:02X}{:02X}'.format(
*struct.unpack(guid_format, buf)
)
def _find_meta_region(self):
def _find_meta_region(self) -> CaptureRegion | None:
# The region table entries start after a 16-byte table header
region_entry_first = 16
@@ -836,7 +851,7 @@ class VHDXInspector(FileInspector):
self._trace('Did not find metadata region')
return None
def _find_meta_entry(self, desired_guid):
def _find_meta_entry(self, desired_guid: str) -> CaptureRegion | None:
meta_buffer = self.region('metadata').data
if len(meta_buffer) < 32:
# Not enough data yet for full header
@@ -886,14 +901,14 @@ class VHDXInspector(FileInspector):
return None
@property
def virtual_size(self):
def virtual_size(self) -> int:
# Until we have found the offset and have enough metadata buffered
# to read it, return "unknown"
if not self.has_region('vds') or not self.region('vds').complete:
return 0
(size,) = struct.unpack('<Q', self.region('vds').data)
return size
return cast(int, size)
# The VMDK format comes in a large number of variations, but the
@@ -944,8 +959,8 @@ class VMDKInspector(FileInspector):
MARKER_EOS = 0
MARKER_FOOTER = 3
def _initialize(self):
self.desc_text = None
def _initialize(self) -> None:
self.desc_text: str | None = None
# This is the header for "Hosted Sparse Extent" type files. It may
# or may not be used, depending on what kind of VMDK we are about to
# read.
@@ -961,7 +976,9 @@ class VMDKInspector(FileInspector):
)
self.add_safety_check(SafetyCheck('descriptor', self.check_descriptor))
def _parse_sparse_header(self, region, offset=0):
def _parse_sparse_header(
self, region: str, offset: int = 0
) -> tuple[bytes, int, int, int, int]:
(
sig,
ver,
@@ -979,7 +996,7 @@ class VMDKInspector(FileInspector):
)
return sig, ver, desc_sec, desc_num, gdOffset
def post_process(self):
def post_process(self) -> None:
# If we have just completed the header region, we need to calculate
# the location and length of the descriptor, which should immediately
# follow and may have been partially-read in this read. If the header
@@ -1039,11 +1056,11 @@ class VMDKInspector(FileInspector):
'descriptor', CaptureRegion(desc_offset, desc_size)
)
def region_complete(self, region_name):
def region_complete(self, region_name: str) -> None:
if region_name == 'descriptor':
self._parse_descriptor()
def _parse_descriptor(self):
def _parse_descriptor(self) -> None:
try:
# The sparse descriptor is null-padded to 512 bytes. Find the
# first one and use it as the end of the text string.
@@ -1079,14 +1096,14 @@ class VMDKInspector(FileInspector):
self.vmdktype = vmdktype
@property
def format_match(self):
def format_match(self) -> bool:
if self.has_region('header'):
return self.region('header').data.startswith(b'KDMV')
else:
return self.vmdktype != 'formatnotfound'
@property
def virtual_size(self):
def virtual_size(self) -> int:
if not self.desc_text:
# Not enough data yet
return 0
@@ -1095,6 +1112,8 @@ class VMDKInspector(FileInspector):
LOG.warning('Unsupported VMDK format %r', self.vmdktype)
return 0
sectors: int
# If we have the descriptor, we definitely have the header
_sig, _ver, _flags, sectors, _grain, _desc_sec, _desc_num = (
struct.unpack('<IIIQQQQ', self.region('header').data[:44])
@@ -1102,7 +1121,7 @@ class VMDKInspector(FileInspector):
return sectors * 512
def check_descriptor(self):
def check_descriptor(self) -> None:
if not self.desc_text:
raise SafetyViolation(_('No descriptor found'))
@@ -1148,7 +1167,7 @@ class VMDKInspector(FileInspector):
LOG.error('VMDK file specified no extents')
raise SafetyViolation(_('No extents found'))
def check_footer(self):
def check_footer(self) -> None:
h_sig, h_ver, h_desc_sec, h_desc_num, h_goff = (
self._parse_sparse_header('header')
)
@@ -1200,27 +1219,30 @@ class VDIInspector(FileInspector):
NAME = 'vdi'
def _initialize(self):
def _initialize(self) -> None:
self.new_region('header', CaptureRegion(0, 512))
self.add_safety_check(SafetyCheck.null())
@property
def format_match(self):
def format_match(self) -> bool:
if not self.region('header').complete:
return False
signature: int
(signature,) = struct.unpack(
'<I', self.region('header').data[0x40:0x44]
)
return signature == 0xBEDA107F
@property
def virtual_size(self):
def virtual_size(self) -> int:
if not self.region('header').complete:
return 0
if not self.format_match:
return 0
size: int
(size,) = struct.unpack('<Q', self.region('header').data[0x170:0x178])
return size
@@ -1256,20 +1278,20 @@ class ISOInspector(FileInspector):
NAME = 'iso'
def _initialize(self):
def _initialize(self) -> None:
self.new_region('system_area', CaptureRegion(0, 32 * units.Ki))
self.new_region('header', CaptureRegion(32 * units.Ki, 2 * units.Ki))
self.add_safety_check(SafetyCheck.null())
@property
def format_match(self):
def format_match(self) -> bool:
if not self.complete:
return False
signature = self.region('header').data[1:6]
return signature in (b'CD001', b'NSR02', b'NSR03')
@property
def virtual_size(self):
def virtual_size(self) -> int:
if not self.complete:
return 0
if not self.format_match:
@@ -1287,6 +1309,10 @@ class ISOInspector(FileInspector):
descriptor_type = self.region('header').data[0]
if descriptor_type != 1:
return 0
logical_block_size_data: bytes | bytearray
logical_block_size: int
# The size in bytes of a logical block is stored at offset 128
# and is 2 bytes long encoded in both little and big endian
# int16_LSB-MSB so the field is 4 bytes long
@@ -1298,6 +1324,10 @@ class ISOInspector(FileInspector):
(logical_block_size,) = struct.unpack(
'<H', logical_block_size_data[:2]
)
volume_space_size_data: bytes | bytearray
volume_space_size: int
# The volume space size is the total number of logical blocks
# and is stored at offset 80 and is 8 bytes long
# as with the logical block size the field is encoded in both
@@ -1320,7 +1350,7 @@ class GPTInspector(FileInspector):
MBR_PTE_START = 446
MEDIA_TYPE_FDISK = 0xF8
def _initialize(self):
def _initialize(self) -> None:
self.new_region('mbr', CaptureRegion(0, 512))
# TODO(danms): If we start inspecting the contents of the GPT
# structures themselves, we need to realize that they are block-aligned
@@ -1332,7 +1362,7 @@ class GPTInspector(FileInspector):
# self.new_region('gpt_backup', EndCaptureRegion(512))
self.add_safety_check(SafetyCheck('mbr', self.check_mbr_partitions))
def _check_for_fat(self):
def _check_for_fat(self) -> bool:
# A FAT filesystem looks like an MBR, but actually starts with a VBR,
# which has the same signature as an MBR, but with more specifics in
# the BPB (BIOS Parameter Block).
@@ -1346,7 +1376,7 @@ class GPTInspector(FileInspector):
return num_fats == 2 and media_desc == self.MEDIA_TYPE_FDISK
@property
def format_match(self):
def format_match(self) -> bool:
if not self.region('mbr').complete:
return False
# Check to see if this looks like a VBR from a FAT filesystem so we
@@ -1355,7 +1385,7 @@ class GPTInspector(FileInspector):
(mbr_sig,) = struct.unpack('<H', self.region('mbr').data[510:512])
return mbr_sig == self.MBR_SIGNATURE and not is_fat
def check_mbr_partitions(self):
def check_mbr_partitions(self) -> None:
valid_partitions = []
found_gpt = False
for i in range(4):
@@ -1409,12 +1439,12 @@ class LUKSHeader(TypedDict, total=False):
class LUKSInspector(FileInspector):
NAME = 'luks'
def _initialize(self):
def _initialize(self) -> None:
self.new_region('header', CaptureRegion(0, 592))
self.add_safety_check(SafetyCheck('version', self.check_version))
@property
def format_match(self):
def format_match(self) -> bool:
return self.region('header').data[:6] == b'LUKS\xba\xbe'
@property
@@ -1432,7 +1462,7 @@ class LUKSInspector(FileInspector):
]
return cast(LUKSHeader, dict(zip(names, fields)))
def check_version(self):
def check_version(self) -> None:
header = self.header_items
if header['version'] != 1:
raise SafetyViolation(
@@ -1440,7 +1470,7 @@ class LUKSInspector(FileInspector):
)
@property
def virtual_size(self):
def virtual_size(self) -> int:
# NOTE(danms): This will not be correct until/unless the whole stream
# has been read, since all we have is (effectively) the size of the
# header. This is similar to how RawFileInspector works.
@@ -1468,10 +1498,15 @@ class InspectWrapper:
the detected formats to some smaller scope.
"""
def __init__(self, source, expected_format=None, allowed_formats=None):
def __init__(
self,
source: IO[bytes],
expected_format: str | None = None,
allowed_formats: list[str] | None = None,
) -> None:
self._source = source
self._expected_format = expected_format
self._errored_inspectors = set()
self._errored_inspectors: set[FileInspector] = set()
self._inspectors = {
v()
for k, v in ALL_FORMATS.items()
@@ -1479,10 +1514,10 @@ class InspectWrapper:
}
self._finished = False
def __iter__(self):
def __iter__(self) -> Iterator[bytes]:
return self
def _process_chunk(self, chunk):
def _process_chunk(self, chunk: bytes) -> None:
for inspector in [
i for i in self._inspectors if i not in self._errored_inspectors
]:
@@ -1520,7 +1555,7 @@ class InspectWrapper:
f'Content does not match expected format {inspector.NAME!r}'
)
def __next__(self):
def __next__(self) -> bytes:
try:
chunk = next(self._source)
except StopIteration:
@@ -1529,23 +1564,23 @@ class InspectWrapper:
self._process_chunk(chunk)
return chunk
def read(self, size):
def read(self, size: int) -> bytes:
chunk = self._source.read(size)
self._process_chunk(chunk)
return chunk
def _finish(self):
def _finish(self) -> None:
for inspector in self._inspectors:
inspector.finish()
self._finished = True
def close(self):
def close(self) -> None:
if hasattr(self._source, 'close'):
self._source.close()
self._finish()
@property
def formats(self):
def formats(self) -> list[FileInspector] | None:
"""The formats (potentially multiple) determined from the content.
This is just like format, but returns a list of formats that matched,
@@ -1580,7 +1615,7 @@ class InspectWrapper:
return matches
@property
def format(self):
def format(self) -> FileInspector | None:
"""The format determined from the content.
If this is None, a decision has not been reached. Otherwise,
@@ -1628,7 +1663,7 @@ ALL_FORMATS: dict[str, type[FileInspector]] = {
}
def get_inspector(format_name):
def get_inspector(format_name: str) -> type[FileInspector] | None:
"""Returns a FileInspector class based on the given name.
:param format_name: The name of the disk_format (raw, qcow2, etc).
@@ -1638,7 +1673,7 @@ def get_inspector(format_name):
return ALL_FORMATS.get(format_name)
def detect_file_format(filename):
def detect_file_format(filename: str) -> FileInspector | None:
"""Attempts to detect the format of a file.
This runs through a file one time, running all the known inspectors in

View File

@@ -27,6 +27,7 @@ Helper methods to deal with images.
import json
import re
from typing import Any
import debtcollector
@@ -61,7 +62,14 @@ class QemuImgInfo:
re.I,
)
def __init__(self, cmd_output=None, format='human'):
def __init__(
self,
cmd_output: str | bytes | bytearray | None = None,
format: str = 'human',
) -> None:
if isinstance(cmd_output, bytes | bytearray):
cmd_output = cmd_output.decode()
if format == 'json':
details = json.loads(cmd_output or '{}')
self.image = details.get('filename')
@@ -94,7 +102,7 @@ class QemuImgInfo:
self.encrypted = details.get('encrypted')
self.format_specific = None
def __str__(self):
def __str__(self) -> str:
lines = [
f'image: {self.image}',
f'file_format: {self.file_format}',
@@ -112,7 +120,7 @@ class QemuImgInfo:
lines.append(f"format_specific: {self.format_specific}")
return "\n".join(lines)
def _canonicalize(self, field):
def _canonicalize(self, field: str) -> str:
# Standardize on underscores/lc/no dash and no spaces
# since qemu seems to have mixed outputs here... and
# this format allows for better integration with python
@@ -122,7 +130,7 @@ class QemuImgInfo:
field = field.replace(c, '_')
return field
def _extract_bytes(self, details):
def _extract_bytes(self, details: str) -> int:
# Replace it with the byte amount
real_size = self.SIZE_RE.search(details)
if not real_size:
@@ -139,12 +147,16 @@ class QemuImgInfo:
# Allow abbreviated unit such as K to mean KB for compatibility.
if len(unit_of_measure) == 1 and unit_of_measure != 'B':
unit_of_measure += 'B'
return strutils.string_to_bytes(
f'{magnitude}{unit_of_measure}', return_int=True
return int(
strutils.string_to_bytes(
f'{magnitude}{unit_of_measure}', return_int=True
)
)
def _extract_details(self, root_cmd, root_details, lines_after):
real_details = root_details
def _extract_details(
self, root_cmd: str, root_details: str, lines_after: list[str]
) -> str | int | list[dict[str, str]]:
real_details: str | int | list[dict[str, str]] = root_details
if root_cmd == 'backing_file':
# Replace it with the real backing file
backing_match = self.BACKING_FILE_RE.match(root_details)
@@ -157,7 +169,7 @@ class QemuImgInfo:
else:
real_details = self._extract_bytes(root_details)
elif root_cmd == 'file_format':
real_details = real_details.strip().lower()
real_details = root_details.strip().lower()
elif root_cmd == 'snapshot_list':
# Next line should be a header, starting with 'ID'
if not lines_after or not lines_after.pop(0).startswith("ID"):
@@ -189,7 +201,7 @@ class QemuImgInfo:
)
return real_details
def _parse(self, cmd_output):
def _parse(self, cmd_output: str) -> dict[str, Any]:
# Analysis done of qemu-img.c to figure out what is going on here
# Find all points start with some chars and then a ':' then a newline
# and then handle the results of those 'top level' items in a separate

View File

@@ -42,8 +42,6 @@ show_error_context = true
strict = true
# debtcollector is untyped (for now)
disallow_untyped_decorators = false
# some of our own functional are untyped (for now)
disallow_untyped_calls = false
exclude = '''
(?x)(
doc
@@ -51,14 +49,6 @@ exclude = '''
)
'''
[[tool.mypy.overrides]]
module = [
"oslo_utils.imageutils.*",
]
disallow_untyped_calls = false
disallow_untyped_defs = false
disallow_subclassing_any = false
[[tool.mypy.overrides]]
module = [
"oslo_utils.fixture",