Skip to content

Commit 414734f

Browse files
committed
Change DataStruct to work with metaclass
1 parent 90661c2 commit 414734f

File tree

2 files changed

+293
-197
lines changed

2 files changed

+293
-197
lines changed

quickstruct/quickstruct.py

Lines changed: 141 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -1,177 +1,172 @@
11
from .common import *
2-
from .common import Padding
3-
from .struct_builder import FieldInfo, FixedStructBuilder, DynamicStructBuilder, IStructBuilder, StructFlags
2+
from .common import TypeMeta
3+
from .error import InheritanceError, UnoverridbaleFieldError, OverrideError, SizeError, UnsafeOverrideError
4+
from .field import FieldFlags, StructField, StructPaddingField
5+
from .struct_builder import FieldInfo, StructBuilder, StructFlags
46

57
from io import BytesIO
6-
import pprint
78
import typing
89

910

10-
class StructField:
11-
_type: typing.Type[Type]
12-
_name: str
11+
class DataStructMeta(TypeMeta):
12+
__sflags__: StructFlags
13+
__size__: int
14+
__fields__: typing.Dict[str, FieldInfo]
15+
16+
@property
17+
def fields(self) -> typing.List[FieldInfo]:
18+
return list(self.__fields__.values())
19+
20+
@property
21+
def is_final(self) -> bool:
22+
"""
23+
Returns True if this struct is final.
24+
"""
25+
return bool(self.__sflags__ & StructFlags.Final)
26+
27+
@property
28+
def is_protected(self) -> bool:
29+
"""
30+
Returns True if the struct is protected.
31+
"""
32+
return bool(self.__sflags__ & StructFlags.Protected)
33+
34+
@property
35+
def alignment(self) -> int:
36+
"""
37+
Returns the alignment of the struct (which is the size of the struct).
38+
"""
39+
return self.size
40+
41+
@property
42+
def size(self) -> int:
43+
"""
44+
Returns the size of the struct. If the struct has dynamic size, this will return -1.
45+
"""
46+
return self.__size__
47+
48+
@property
49+
def is_dynamic_size(self):
50+
return self.size == -1
51+
52+
@property
53+
def is_fixed_size(self):
54+
return not self.is_dynamic_size
1355

14-
def __init__(self, typ: typing.Type[Type]) -> None:
15-
self._type = typ
56+
@classmethod
57+
def __prepare__(cls, _, bases, flags: StructFlags = StructFlags.Default, **__):
58+
namespace = {}
1659

17-
def __set_name__(self, owner, name):
18-
self._name = f"<field>_{name}"
19-
setattr(owner, name, self)
60+
fields: typing.List[FieldInfo] = []
2061

21-
def __get__(self, instance, _):
22-
if instance is None:
23-
return self
24-
return getattr(instance, self._name)
62+
for base in bases:
63+
if not isinstance(base, DataStructMeta):
64+
continue
65+
if base.is_final:
66+
raise InheritanceError(f"Cannot inherit from a final struct '{base.__name__}'.")
67+
fields.extend(base.__fields__.values())
2568

26-
def __set__(self, instance, value):
27-
if not isinstance(value, self._type):
28-
raise TypeError(f"Expected {self._type}, got {type(value)}")
29-
setattr(instance, self._name, value)
69+
namespace['__fields__'] = fields
70+
namespace['__sflags__'] = flags
3071

72+
return namespace
3173

32-
class DataStruct(Type):
33-
"""Base type for all data structures."""
74+
def __new__(cls, name, bases, namespace, flags: StructFlags = StructFlags.Default, **_):
75+
base_fields: typing.List[FieldInfo] = namespace['__fields__']
76+
flags: StructFlags = namespace['__sflags__']
77+
78+
try:
79+
annotations = namespace['__annotations__']
80+
except KeyError:
81+
annotations = {}
82+
83+
fields: typing.Dict[str, FieldInfo] = {}
84+
85+
defaults = {n: v for n, v in namespace.items() if n in annotations}
86+
namespace = {n: v for n, v in namespace.items() if n not in annotations}
87+
88+
field_flags: FieldFlags = FieldFlags.NONE
89+
if flags & StructFlags.Protected:
90+
field_flags |= FieldFlags.Protected
91+
92+
for field in base_fields + [FieldInfo(name, typ, field_flags) for name, typ in annotations.items()]:
93+
if field.name in fields:
94+
if fields[field.name].is_protected:
95+
raise UnoverridbaleFieldError(f"Field '{field.name}' is protected and cannot be overridden.")
96+
if flags & StructFlags.TypeSafeOverride and field.type != fields[field.name].type:
97+
raise UnsafeOverrideError(f"Field '{field.name}' is overriding a field from a base type with a different type.")
98+
if not flags & StructFlags.AllowOverride:
99+
raise OverrideError(f"Field '{field.name}' is overriding a field from a base type.")
100+
fields[field.name] = field
101+
102+
builder = StructBuilder()
103+
for field in fields.values():
104+
builder.add_field(field)
105+
if flags & StructFlags.ReorderFields:
106+
builder.reorder_fields()
107+
if not flags & StructFlags.NoAlignment:
108+
builder.align_fields(None if flags & StructFlags.AlignAuto else (1 << flags & StructFlags.AlignmentMask))
109+
fields = {field.name: field for field in builder.build()}
110+
111+
if flags & StructFlags.ForceFixedSize and builder.is_dynamic_size:
112+
raise SizeError(f"Cannot force fixed size on a dynamic struct '{name}'.")
113+
114+
namespace['__slots__'] = tuple(f"__field_{field}__" for field in fields.keys())
115+
namespace['__defaults__'] = defaults
116+
namespace['__fields__'] = fields
117+
namespace['__size__'] = builder.size
34118

35-
__s_fields__: typing.Dict[str, typing.Type[Type]]
36-
__s_flags__: StructFlags = StructFlags.Default
37-
__s_size__: int
38-
__s_alignment__: int
119+
namespace.update({
120+
field.name: StructField(field.type) if not issubclass(field.type, Padding) else StructPaddingField()
121+
for field in fields.values()
122+
})
123+
124+
return super().__new__(cls, name, bases, namespace, **_)
125+
126+
def __iter__(self):
127+
return iter(self.__fields__.values())
128+
129+
def __repr__(self) -> str:
130+
if not self.__fields__:
131+
return "{}"
132+
return str({field.name: field.type for field in self.__fields__.values()})
133+
134+
135+
class DataStruct(Type, metaclass=DataStructMeta):
136+
"""Base type for all data structures."""
39137

40138
def __init__(self, **kwargs) -> None:
41139
super().__init__()
42140
try:
43-
for name, value in kwargs.items():
44-
# if not isinstance(value, self.__fields__[name]):
45-
# raise TypeError(f"Expected {self.__fields__[name]}, got {type(value)}")
141+
for name in type(self).__fields__:
142+
if name in kwargs:
143+
value = kwargs[name]
144+
elif name in type(self).__defaults__:
145+
value = type(self).__defaults__[name]
146+
else:
147+
continue
46148
setattr(self, name, value)
47149
except KeyError:
48150
raise TypeError(f"Unknown field {name}")
49151

50-
def __init_subclass__(cls, flags: StructFlags = StructFlags.Default):
51-
if any(map(lambda base: issubclass(base, DataStruct) and base.__s_flags__ & StructFlags.Final, cls.__bases__)):
52-
raise TypeError("Cannot inherit from a final struct")
53-
builder: IStructBuilder
54-
if flags & StructFlags.ForceFixedSize:
55-
builder = FixedStructBuilder(flags)
56-
else:
57-
builder = DynamicStructBuilder(flags)
58-
if flags & StructFlags.ForceFixedSize:
59-
if flags & StructFlags.ReorderFields:
60-
builder.reorder_fields()
61-
if not flags & StructFlags.NoAlignment:
62-
builder.align_fields()
63-
64-
potential_fields: typing.List[FieldInfo] = []
65-
66-
if flags & StructFlags.ForceDataOnly:
67-
if any(map(lambda base: not issubclass(base, DataStruct) or not base.__s_flags__ & StructFlags.ForceDataOnly, cls.__bases__)):
68-
raise TypeError(f"{cls.__name__} is marked as data-only but it inherits non-data-only struct.")
69-
70-
paddings = 0
71-
for i, base in enumerate(cls.__bases__):
72-
if not issubclass(base, DataStruct) or base is DataStruct: continue
73-
if base.__s_flags__ & StructFlags.AllowInline:
74-
for name, typ in base.__s_fields__.items():
75-
if issubclass(typ, Padding):
76-
name = f"<padding>_{paddings}"
77-
paddings += 1
78-
potential_fields.append(FieldInfo(name, typ))
79-
else:
80-
potential_fields.append(FieldInfo(f"<base>_{i}", base))
81-
82-
non_data_fields = []
83-
84-
for name, typ in cls.__annotations__.items():
85-
if name.startswith('__') and name.endswith('__'): continue
86-
if issubclass(typ, DataStruct):
87-
if not typ.__s_flags__ & StructFlags.ForceDataOnly:
88-
non_data_fields.append(name)
89-
if issubclass(typ, Type):
90-
potential_fields.append(FieldInfo(name, typ))
91-
else:
92-
non_data_fields.append(name)
93-
94-
# for name, value in vars(cls).items():
95-
# if name.startswith('__') and name.endswith('__'): continue
96-
# typ = type(value)
97-
# if issubclass(typ, DataStruct):
98-
# if not typ.__s_flags__ & StructFlags.ForceDataOnly:
99-
# non_data_fields.append(name)
100-
# if issubclass(typ, Type):
101-
# potential_fields.append(FieldInfo(name, type(value)))
102-
# else:
103-
# non_data_fields.append(name)
104-
105-
if flags & StructFlags.ForceDataOnly and non_data_fields:
106-
raise TypeError(f"{cls.__name__} is marked as data-only but it contains non-data-only fields: {non_data_fields}.")
107-
108-
for field in potential_fields:
109-
builder.add_field(field, allow_overwrite=flags & StructFlags.AllowOverride, force_safe_overwrite=flags & StructFlags.ForceSafeOverride)
110-
111-
fields = builder.build()
112-
if builder.size != -1:
113-
builder = FixedStructBuilder(flags)
114-
for field in fields.values():
115-
builder.add_field(field, allow_overwrite=flags & StructFlags.AllowOverride, force_safe_overwrite=flags & StructFlags.ForceSafeOverride)
116-
if flags & StructFlags.ReorderFields:
117-
builder.reorder_fields()
118-
if not flags & StructFlags.NoAlignment:
119-
builder.align_fields()
120-
fields = builder.build()
121-
122-
for name, field in fields.items():
123-
value = getattr(cls, name, None)
124-
StructField(field.type).__set_name__(cls, name)
125-
if value is not None:
126-
setattr(cls, name, value)
127-
cls.__s_fields__ = {name: field.type for name, field in fields.items()}
128-
cls.__s_flags__ = flags
129-
cls.__s_size__ = builder.size
130-
cls.__s_alignment__ = 0
131-
cls.__slots__ = fields.keys()
132-
133-
@classmethod
134-
def size(cls):
135-
cls.__s_size__
136-
137-
@classmethod
138-
def alignment(cls) -> int:
139-
return cls.__s_alignment__
140-
141152
@classmethod
142-
def from_bytes(cls: typing.Union[typing.Type[T], "DataStruct"], data: typing.Union[bytes, BytesIO]) -> T:
143-
result = cls()
153+
def from_bytes(cls, data: typing.Union[bytes, BytesIO]):
144154
if isinstance(data, bytes):
145155
data = BytesIO(data)
146-
for field, ftype in cls.__s_fields__.items():
147-
value = ftype.from_bytes(data)
148-
setattr(result, field, value)
149-
return result
156+
return cls(**{field.name: field.type.from_bytes(data) for field in cls.__fields__.values()})
150157

151158
def to_bytes(self) -> bytes:
152-
try:
153-
return b"".join(
154-
typ.to_bytes(getattr(self, field, None)) for field, typ in self.__s_fields__.items()
155-
)
156-
except AttributeError:
157-
raise ValueError(f"One of the fields is not initialized") from None
158-
159-
@classmethod
160-
def __is_instance__(cls, instance) -> bool:
161-
return type.__instancecheck__(cls, instance)
162-
163-
@classmethod
164-
def __class_str__(cls) -> str:
165-
if not cls.__s_fields__:
166-
return "{}"
167-
return pprint.pformat(cls.__s_fields__)
159+
data = []
160+
for field in type(self).__fields__.values():
161+
try:
162+
value = getattr(self, field.name)
163+
except AttributeError:
164+
raise ValueError(f"Field '{field.name}' in struct {type(self).__name__} is not initialized") from None
165+
data.append(field.type.to_bytes(value))
166+
return b"".join(data)
168167

169168
def __repr__(self) -> str:
170-
if not self.__s_fields__:
171-
return "{}"
172-
return pprint.pformat({
173-
field: getattr(self, field, None) for field in self.__s_fields__
174-
})
169+
return str({field: getattr(self, field, None) for field in type(self).__fields__})
175170

176171

177172
__all__ = [

0 commit comments

Comments
 (0)