|
1 | 1 | from .common import * |
2 | | -from .common import Padding |
3 | | -from .struct_builder import FieldInfo, FixedStructBuilder, DynamicStructBuilder, IStructBuilder, StructFlags |
| 2 | +from .common import TypeMeta |
| 3 | +from .error import InheritanceError, UnoverridbaleFieldError, OverrideError, SizeError, UnsafeOverrideError |
| 4 | +from .field import FieldFlags, StructField, StructPaddingField |
| 5 | +from .struct_builder import FieldInfo, StructBuilder, StructFlags |
4 | 6 |
|
5 | 7 | from io import BytesIO |
6 | | -import pprint |
7 | 8 | import typing |
8 | 9 |
|
9 | 10 |
|
10 | | -class StructField: |
11 | | - _type: typing.Type[Type] |
12 | | - _name: str |
| 11 | +class DataStructMeta(TypeMeta): |
| 12 | + __sflags__: StructFlags |
| 13 | + __size__: int |
| 14 | + __fields__: typing.Dict[str, FieldInfo] |
| 15 | + |
| 16 | + @property |
| 17 | + def fields(self) -> typing.List[FieldInfo]: |
| 18 | + return list(self.__fields__.values()) |
| 19 | + |
| 20 | + @property |
| 21 | + def is_final(self) -> bool: |
| 22 | + """ |
| 23 | + Returns True if this struct is final. |
| 24 | + """ |
| 25 | + return bool(self.__sflags__ & StructFlags.Final) |
| 26 | + |
| 27 | + @property |
| 28 | + def is_protected(self) -> bool: |
| 29 | + """ |
| 30 | + Returns True if the struct is protected. |
| 31 | + """ |
| 32 | + return bool(self.__sflags__ & StructFlags.Protected) |
| 33 | + |
| 34 | + @property |
| 35 | + def alignment(self) -> int: |
| 36 | + """ |
| 37 | + Returns the alignment of the struct (which is the size of the struct). |
| 38 | + """ |
| 39 | + return self.size |
| 40 | + |
| 41 | + @property |
| 42 | + def size(self) -> int: |
| 43 | + """ |
| 44 | + Returns the size of the struct. If the struct has dynamic size, this will return -1. |
| 45 | + """ |
| 46 | + return self.__size__ |
| 47 | + |
| 48 | + @property |
| 49 | + def is_dynamic_size(self): |
| 50 | + return self.size == -1 |
| 51 | + |
| 52 | + @property |
| 53 | + def is_fixed_size(self): |
| 54 | + return not self.is_dynamic_size |
13 | 55 |
|
14 | | - def __init__(self, typ: typing.Type[Type]) -> None: |
15 | | - self._type = typ |
| 56 | + @classmethod |
| 57 | + def __prepare__(cls, _, bases, flags: StructFlags = StructFlags.Default, **__): |
| 58 | + namespace = {} |
16 | 59 |
|
17 | | - def __set_name__(self, owner, name): |
18 | | - self._name = f"<field>_{name}" |
19 | | - setattr(owner, name, self) |
| 60 | + fields: typing.List[FieldInfo] = [] |
20 | 61 |
|
21 | | - def __get__(self, instance, _): |
22 | | - if instance is None: |
23 | | - return self |
24 | | - return getattr(instance, self._name) |
| 62 | + for base in bases: |
| 63 | + if not isinstance(base, DataStructMeta): |
| 64 | + continue |
| 65 | + if base.is_final: |
| 66 | + raise InheritanceError(f"Cannot inherit from a final struct '{base.__name__}'.") |
| 67 | + fields.extend(base.__fields__.values()) |
25 | 68 |
|
26 | | - def __set__(self, instance, value): |
27 | | - if not isinstance(value, self._type): |
28 | | - raise TypeError(f"Expected {self._type}, got {type(value)}") |
29 | | - setattr(instance, self._name, value) |
| 69 | + namespace['__fields__'] = fields |
| 70 | + namespace['__sflags__'] = flags |
30 | 71 |
|
| 72 | + return namespace |
31 | 73 |
|
32 | | -class DataStruct(Type): |
33 | | - """Base type for all data structures.""" |
| 74 | + def __new__(cls, name, bases, namespace, flags: StructFlags = StructFlags.Default, **_): |
| 75 | + base_fields: typing.List[FieldInfo] = namespace['__fields__'] |
| 76 | + flags: StructFlags = namespace['__sflags__'] |
| 77 | + |
| 78 | + try: |
| 79 | + annotations = namespace['__annotations__'] |
| 80 | + except KeyError: |
| 81 | + annotations = {} |
| 82 | + |
| 83 | + fields: typing.Dict[str, FieldInfo] = {} |
| 84 | + |
| 85 | + defaults = {n: v for n, v in namespace.items() if n in annotations} |
| 86 | + namespace = {n: v for n, v in namespace.items() if n not in annotations} |
| 87 | + |
| 88 | + field_flags: FieldFlags = FieldFlags.NONE |
| 89 | + if flags & StructFlags.Protected: |
| 90 | + field_flags |= FieldFlags.Protected |
| 91 | + |
| 92 | + for field in base_fields + [FieldInfo(name, typ, field_flags) for name, typ in annotations.items()]: |
| 93 | + if field.name in fields: |
| 94 | + if fields[field.name].is_protected: |
| 95 | + raise UnoverridbaleFieldError(f"Field '{field.name}' is protected and cannot be overridden.") |
| 96 | + if flags & StructFlags.TypeSafeOverride and field.type != fields[field.name].type: |
| 97 | + raise UnsafeOverrideError(f"Field '{field.name}' is overriding a field from a base type with a different type.") |
| 98 | + if not flags & StructFlags.AllowOverride: |
| 99 | + raise OverrideError(f"Field '{field.name}' is overriding a field from a base type.") |
| 100 | + fields[field.name] = field |
| 101 | + |
| 102 | + builder = StructBuilder() |
| 103 | + for field in fields.values(): |
| 104 | + builder.add_field(field) |
| 105 | + if flags & StructFlags.ReorderFields: |
| 106 | + builder.reorder_fields() |
| 107 | + if not flags & StructFlags.NoAlignment: |
| 108 | + builder.align_fields(None if flags & StructFlags.AlignAuto else (1 << flags & StructFlags.AlignmentMask)) |
| 109 | + fields = {field.name: field for field in builder.build()} |
| 110 | + |
| 111 | + if flags & StructFlags.ForceFixedSize and builder.is_dynamic_size: |
| 112 | + raise SizeError(f"Cannot force fixed size on a dynamic struct '{name}'.") |
| 113 | + |
| 114 | + namespace['__slots__'] = tuple(f"__field_{field}__" for field in fields.keys()) |
| 115 | + namespace['__defaults__'] = defaults |
| 116 | + namespace['__fields__'] = fields |
| 117 | + namespace['__size__'] = builder.size |
34 | 118 |
|
35 | | - __s_fields__: typing.Dict[str, typing.Type[Type]] |
36 | | - __s_flags__: StructFlags = StructFlags.Default |
37 | | - __s_size__: int |
38 | | - __s_alignment__: int |
| 119 | + namespace.update({ |
| 120 | + field.name: StructField(field.type) if not issubclass(field.type, Padding) else StructPaddingField() |
| 121 | + for field in fields.values() |
| 122 | + }) |
| 123 | + |
| 124 | + return super().__new__(cls, name, bases, namespace, **_) |
| 125 | + |
| 126 | + def __iter__(self): |
| 127 | + return iter(self.__fields__.values()) |
| 128 | + |
| 129 | + def __repr__(self) -> str: |
| 130 | + if not self.__fields__: |
| 131 | + return "{}" |
| 132 | + return str({field.name: field.type for field in self.__fields__.values()}) |
| 133 | + |
| 134 | + |
| 135 | +class DataStruct(Type, metaclass=DataStructMeta): |
| 136 | + """Base type for all data structures.""" |
39 | 137 |
|
40 | 138 | def __init__(self, **kwargs) -> None: |
41 | 139 | super().__init__() |
42 | 140 | try: |
43 | | - for name, value in kwargs.items(): |
44 | | - # if not isinstance(value, self.__fields__[name]): |
45 | | - # raise TypeError(f"Expected {self.__fields__[name]}, got {type(value)}") |
| 141 | + for name in type(self).__fields__: |
| 142 | + if name in kwargs: |
| 143 | + value = kwargs[name] |
| 144 | + elif name in type(self).__defaults__: |
| 145 | + value = type(self).__defaults__[name] |
| 146 | + else: |
| 147 | + continue |
46 | 148 | setattr(self, name, value) |
47 | 149 | except KeyError: |
48 | 150 | raise TypeError(f"Unknown field {name}") |
49 | 151 |
|
50 | | - def __init_subclass__(cls, flags: StructFlags = StructFlags.Default): |
51 | | - if any(map(lambda base: issubclass(base, DataStruct) and base.__s_flags__ & StructFlags.Final, cls.__bases__)): |
52 | | - raise TypeError("Cannot inherit from a final struct") |
53 | | - builder: IStructBuilder |
54 | | - if flags & StructFlags.ForceFixedSize: |
55 | | - builder = FixedStructBuilder(flags) |
56 | | - else: |
57 | | - builder = DynamicStructBuilder(flags) |
58 | | - if flags & StructFlags.ForceFixedSize: |
59 | | - if flags & StructFlags.ReorderFields: |
60 | | - builder.reorder_fields() |
61 | | - if not flags & StructFlags.NoAlignment: |
62 | | - builder.align_fields() |
63 | | - |
64 | | - potential_fields: typing.List[FieldInfo] = [] |
65 | | - |
66 | | - if flags & StructFlags.ForceDataOnly: |
67 | | - if any(map(lambda base: not issubclass(base, DataStruct) or not base.__s_flags__ & StructFlags.ForceDataOnly, cls.__bases__)): |
68 | | - raise TypeError(f"{cls.__name__} is marked as data-only but it inherits non-data-only struct.") |
69 | | - |
70 | | - paddings = 0 |
71 | | - for i, base in enumerate(cls.__bases__): |
72 | | - if not issubclass(base, DataStruct) or base is DataStruct: continue |
73 | | - if base.__s_flags__ & StructFlags.AllowInline: |
74 | | - for name, typ in base.__s_fields__.items(): |
75 | | - if issubclass(typ, Padding): |
76 | | - name = f"<padding>_{paddings}" |
77 | | - paddings += 1 |
78 | | - potential_fields.append(FieldInfo(name, typ)) |
79 | | - else: |
80 | | - potential_fields.append(FieldInfo(f"<base>_{i}", base)) |
81 | | - |
82 | | - non_data_fields = [] |
83 | | - |
84 | | - for name, typ in cls.__annotations__.items(): |
85 | | - if name.startswith('__') and name.endswith('__'): continue |
86 | | - if issubclass(typ, DataStruct): |
87 | | - if not typ.__s_flags__ & StructFlags.ForceDataOnly: |
88 | | - non_data_fields.append(name) |
89 | | - if issubclass(typ, Type): |
90 | | - potential_fields.append(FieldInfo(name, typ)) |
91 | | - else: |
92 | | - non_data_fields.append(name) |
93 | | - |
94 | | - # for name, value in vars(cls).items(): |
95 | | - # if name.startswith('__') and name.endswith('__'): continue |
96 | | - # typ = type(value) |
97 | | - # if issubclass(typ, DataStruct): |
98 | | - # if not typ.__s_flags__ & StructFlags.ForceDataOnly: |
99 | | - # non_data_fields.append(name) |
100 | | - # if issubclass(typ, Type): |
101 | | - # potential_fields.append(FieldInfo(name, type(value))) |
102 | | - # else: |
103 | | - # non_data_fields.append(name) |
104 | | - |
105 | | - if flags & StructFlags.ForceDataOnly and non_data_fields: |
106 | | - raise TypeError(f"{cls.__name__} is marked as data-only but it contains non-data-only fields: {non_data_fields}.") |
107 | | - |
108 | | - for field in potential_fields: |
109 | | - builder.add_field(field, allow_overwrite=flags & StructFlags.AllowOverride, force_safe_overwrite=flags & StructFlags.ForceSafeOverride) |
110 | | - |
111 | | - fields = builder.build() |
112 | | - if builder.size != -1: |
113 | | - builder = FixedStructBuilder(flags) |
114 | | - for field in fields.values(): |
115 | | - builder.add_field(field, allow_overwrite=flags & StructFlags.AllowOverride, force_safe_overwrite=flags & StructFlags.ForceSafeOverride) |
116 | | - if flags & StructFlags.ReorderFields: |
117 | | - builder.reorder_fields() |
118 | | - if not flags & StructFlags.NoAlignment: |
119 | | - builder.align_fields() |
120 | | - fields = builder.build() |
121 | | - |
122 | | - for name, field in fields.items(): |
123 | | - value = getattr(cls, name, None) |
124 | | - StructField(field.type).__set_name__(cls, name) |
125 | | - if value is not None: |
126 | | - setattr(cls, name, value) |
127 | | - cls.__s_fields__ = {name: field.type for name, field in fields.items()} |
128 | | - cls.__s_flags__ = flags |
129 | | - cls.__s_size__ = builder.size |
130 | | - cls.__s_alignment__ = 0 |
131 | | - cls.__slots__ = fields.keys() |
132 | | - |
133 | | - @classmethod |
134 | | - def size(cls): |
135 | | - cls.__s_size__ |
136 | | - |
137 | | - @classmethod |
138 | | - def alignment(cls) -> int: |
139 | | - return cls.__s_alignment__ |
140 | | - |
141 | 152 | @classmethod |
142 | | - def from_bytes(cls: typing.Union[typing.Type[T], "DataStruct"], data: typing.Union[bytes, BytesIO]) -> T: |
143 | | - result = cls() |
| 153 | + def from_bytes(cls, data: typing.Union[bytes, BytesIO]): |
144 | 154 | if isinstance(data, bytes): |
145 | 155 | data = BytesIO(data) |
146 | | - for field, ftype in cls.__s_fields__.items(): |
147 | | - value = ftype.from_bytes(data) |
148 | | - setattr(result, field, value) |
149 | | - return result |
| 156 | + return cls(**{field.name: field.type.from_bytes(data) for field in cls.__fields__.values()}) |
150 | 157 |
|
151 | 158 | def to_bytes(self) -> bytes: |
152 | | - try: |
153 | | - return b"".join( |
154 | | - typ.to_bytes(getattr(self, field, None)) for field, typ in self.__s_fields__.items() |
155 | | - ) |
156 | | - except AttributeError: |
157 | | - raise ValueError(f"One of the fields is not initialized") from None |
158 | | - |
159 | | - @classmethod |
160 | | - def __is_instance__(cls, instance) -> bool: |
161 | | - return type.__instancecheck__(cls, instance) |
162 | | - |
163 | | - @classmethod |
164 | | - def __class_str__(cls) -> str: |
165 | | - if not cls.__s_fields__: |
166 | | - return "{}" |
167 | | - return pprint.pformat(cls.__s_fields__) |
| 159 | + data = [] |
| 160 | + for field in type(self).__fields__.values(): |
| 161 | + try: |
| 162 | + value = getattr(self, field.name) |
| 163 | + except AttributeError: |
| 164 | + raise ValueError(f"Field '{field.name}' in struct {type(self).__name__} is not initialized") from None |
| 165 | + data.append(field.type.to_bytes(value)) |
| 166 | + return b"".join(data) |
168 | 167 |
|
169 | 168 | def __repr__(self) -> str: |
170 | | - if not self.__s_fields__: |
171 | | - return "{}" |
172 | | - return pprint.pformat({ |
173 | | - field: getattr(self, field, None) for field in self.__s_fields__ |
174 | | - }) |
| 169 | + return str({field: getattr(self, field, None) for field in type(self).__fields__}) |
175 | 170 |
|
176 | 171 |
|
177 | 172 | __all__ = [ |
|
0 commit comments