diff --git a/betterproto/__init__.py b/betterproto/__init__.py index 6c07feb0c..b25eaa65a 100644 --- a/betterproto/__init__.py +++ b/betterproto/__init__.py @@ -255,12 +255,13 @@ class Enum(int, enum.Enum): """Protocol buffers enumeration base class. Acts like `enum.IntEnum`.""" @classmethod - def from_string(cls, name: str) -> int: - """Return the value which corresponds to the string name.""" + def from_value(cls, value: Union[str, int]) -> int: + """Return the value which corresponds to the string value.""" try: - return cls.__members__[name] - except KeyError as e: - raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e + value = int(value) + return cls(value) + except (KeyError, ValueError, TypeError) as e: + raise ValueError(f"Unknown value {value} for enum {cls.__name__}") from e def _pack_fmt(proto_type: str) -> str: @@ -611,8 +612,16 @@ def __bytes__(self) -> bytes: output += _serialize_single(meta.number, TYPE_BYTES, buf) else: for item in value: - output += _serialize_single( - meta.number, meta.proto_type, item, wraps=meta.wraps or "" + output += ( + _serialize_single( + meta.number, + meta.proto_type, + item, + wraps=meta.wraps or "", + ) + # if it's an empty message it still needs to be represented + # as an item in the repeated list + or b"\n\x00" ) elif isinstance(value, dict): for k, v in value.items(): @@ -636,8 +645,12 @@ def __bytes__(self) -> bytes: @classmethod def _type_hint(cls, field_name: str) -> Type: - module = inspect.getmodule(cls) - type_hints = get_type_hints(cls, vars(module)) + global_vars = {} + for base in inspect.getmro(cls): + module = inspect.getmodule(base) + global_vars.update(vars(module)) + + type_hints = get_type_hints(cls, global_vars) return type_hints[field_name] @classmethod @@ -831,9 +844,9 @@ def to_dict( self._betterproto.cls_by_field[field_name] ) # type: ignore if isinstance(v, list): - output[cased_name] = [enum_values[e].name for e in v] + output[cased_name] = [enum_values[e].value for e in v] else: - output[cased_name] = enum_values[v].name + output[cased_name] = enum_values[v].value else: output[cased_name] = v return output @@ -888,9 +901,9 @@ def from_dict(self: T, value: dict) -> T: elif meta.proto_type == TYPE_ENUM: enum_cls = self._betterproto.cls_by_field[field_name] if isinstance(v, list): - v = [enum_cls.from_string(e) for e in v] - elif isinstance(v, str): - v = enum_cls.from_string(v) + v = [enum_cls.from_value(e) for e in v] + else: + v = enum_cls.from_value(v) if v is not None: setattr(self, field_name, v)