Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 27 additions & 14 deletions betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,13 @@ class Enum(int, enum.Enum):
"""Protocol buffers enumeration base class. Acts like `enum.IntEnum`."""

@classmethod
def from_string(cls, name: str) -> int:
"""Return the value which corresponds to the string name."""
def from_value(cls, value: Union[str, int]) -> int:
"""Return the value which corresponds to the string value."""
try:
return cls.__members__[name]
except KeyError as e:
raise ValueError(f"Unknown value {name} for enum {cls.__name__}") from e
value = int(value)
return cls(value)
except (KeyError, ValueError, TypeError) as e:
raise ValueError(f"Unknown value {value} for enum {cls.__name__}") from e


def _pack_fmt(proto_type: str) -> str:
Expand Down Expand Up @@ -611,8 +612,16 @@ def __bytes__(self) -> bytes:
output += _serialize_single(meta.number, TYPE_BYTES, buf)
else:
for item in value:
output += _serialize_single(
meta.number, meta.proto_type, item, wraps=meta.wraps or ""
output += (
_serialize_single(
meta.number,
meta.proto_type,
item,
wraps=meta.wraps or "",
)
# if it's an empty message it still needs to be represented
# as an item in the repeated list
or b"\n\x00"
)
elif isinstance(value, dict):
for k, v in value.items():
Expand All @@ -636,8 +645,12 @@ def __bytes__(self) -> bytes:

@classmethod
def _type_hint(cls, field_name: str) -> Type:
module = inspect.getmodule(cls)
type_hints = get_type_hints(cls, vars(module))
global_vars = {}
for base in inspect.getmro(cls):
module = inspect.getmodule(base)
global_vars.update(vars(module))

type_hints = get_type_hints(cls, global_vars)
return type_hints[field_name]

@classmethod
Expand Down Expand Up @@ -831,9 +844,9 @@ def to_dict(
self._betterproto.cls_by_field[field_name]
) # type: ignore
if isinstance(v, list):
output[cased_name] = [enum_values[e].name for e in v]
output[cased_name] = [enum_values[e].value for e in v]
else:
output[cased_name] = enum_values[v].name
output[cased_name] = enum_values[v].value
else:
output[cased_name] = v
return output
Expand Down Expand Up @@ -888,9 +901,9 @@ def from_dict(self: T, value: dict) -> T:
elif meta.proto_type == TYPE_ENUM:
enum_cls = self._betterproto.cls_by_field[field_name]
if isinstance(v, list):
v = [enum_cls.from_string(e) for e in v]
elif isinstance(v, str):
v = enum_cls.from_string(v)
v = [enum_cls.from_value(e) for e in v]
else:
v = enum_cls.from_value(v)

if v is not None:
setattr(self, field_name, v)
Expand Down