Skip to content

Commit 016372e

Browse files
committed
When encoding values, use a suitable default value if a value is None
1 parent 86c6e08 commit 016372e

File tree

4 files changed

+39
-18
lines changed

4 files changed

+39
-18
lines changed

eip712_structs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
from eip712_structs.types import Address, Array, Boolean, Bytes, Int, String, Uint
44

55
name = 'eip712-structs'
6-
version = '0.1.5'
6+
version = '0.1.6'

eip712_structs/struct.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class MyStruct(EIP712Struct):
2727
struct_instance = MyStruct(some_param='some_value')
2828
"""
2929
def __init__(self, **kwargs):
30-
super(EIP712Struct, self).__init__(self.type_name)
30+
super(EIP712Struct, self).__init__(self.type_name, None)
3131
members = self.get_members()
3232
self.values = dict()
3333
for name, typ in members:

eip712_structs/types.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from typing import Union, Type
2+
from typing import Any, Union, Type
33

44
from eth_utils.crypto import keccak
55
from eth_utils.conversions import to_int
@@ -8,10 +8,17 @@
88
class EIP712Type:
99
"""The base type for members of a struct.
1010
"""
11-
def __init__(self, type_name: str):
11+
def __init__(self, type_name: str, none_val: Any):
1212
self.type_name = type_name
13+
self.none_val = none_val
1314

1415
def encode_value(self, value) -> bytes:
16+
if value is None:
17+
return self._encode_value(self.none_val)
18+
else:
19+
return self._encode_value(value)
20+
21+
def _encode_value(self, value) -> bytes:
1522
"""Given a value, verify it and convert into the format required by the spec.
1623
1724
:param value: A correct input value for the implemented type.
@@ -38,19 +45,19 @@ def __init__(self, member_type: Union[EIP712Type, Type[EIP712Type]], fixed_lengt
3845
type_name = f'{member_type.type_name}[{fixed_length}]'
3946
self.member_type = member_type
4047
self.fixed_length = fixed_length
41-
super(Array, self).__init__(type_name)
48+
super(Array, self).__init__(type_name, [])
4249

43-
def encode_value(self, value):
50+
def _encode_value(self, value):
4451
encoder = self.member_type
4552
encoded_values = [encoder.encode_value(v) for v in value]
4653
return keccak(b''.join(encoded_values))
4754

4855

4956
class Address(EIP712Type):
5057
def __init__(self):
51-
super(Address, self).__init__('address')
58+
super(Address, self).__init__('address', 0)
5259

53-
def encode_value(self, value):
60+
def _encode_value(self, value):
5461
# Some smart conversions - need to get an address as an int
5562
if isinstance(value, bytes):
5663
v = to_int(value)
@@ -63,9 +70,9 @@ def encode_value(self, value):
6370

6471
class Boolean(EIP712Type):
6572
def __init__(self):
66-
super(Boolean, self).__init__('bool')
73+
super(Boolean, self).__init__('bool', False)
6774

68-
def encode_value(self, value):
75+
def _encode_value(self, value):
6976
if value is False:
7077
return Uint(256).encode_value(0)
7178
elif value is True:
@@ -85,9 +92,9 @@ def __init__(self, length: int = 0):
8592
else:
8693
raise ValueError(f'Byte length must be between 1 or 32. Got: {length}')
8794
self.length = length
88-
super(Bytes, self).__init__(type_name)
95+
super(Bytes, self).__init__(type_name, b'')
8996

90-
def encode_value(self, value):
97+
def _encode_value(self, value):
9198
if self.length == 0:
9299
return keccak(value)
93100
else:
@@ -103,18 +110,18 @@ def __init__(self, length: int):
103110
if length < 8 or length > 256 or length % 8 != 0:
104111
raise ValueError(f'Int length must be a multiple of 8, between 8 and 256. Got: {length}')
105112
self.length = length
106-
super(Int, self).__init__(f'int{length}')
113+
super(Int, self).__init__(f'int{length}', 0)
107114

108-
def encode_value(self, value: int):
115+
def _encode_value(self, value: int):
109116
value.to_bytes(self.length // 8, byteorder='big', signed=True) # For validation
110117
return value.to_bytes(32, byteorder='big', signed=True)
111118

112119

113120
class String(EIP712Type):
114121
def __init__(self):
115-
super(String, self).__init__('string')
122+
super(String, self).__init__('string', '')
116123

117-
def encode_value(self, value):
124+
def _encode_value(self, value):
118125
return keccak(text=value)
119126

120127

@@ -124,9 +131,9 @@ def __init__(self, length: int):
124131
if length < 8 or length > 256 or length % 8 != 0:
125132
raise ValueError(f'Uint length must be a multiple of 8, between 8 and 256. Got: {length}')
126133
self.length = length
127-
super(Uint, self).__init__(f'uint{length}')
134+
super(Uint, self).__init__(f'uint{length}', 0)
128135

129-
def encode_value(self, value: int):
136+
def _encode_value(self, value: int):
130137
value.to_bytes(self.length // 8, byteorder='big', signed=False) # For validation
131138
return value.to_bytes(32, byteorder='big', signed=False)
132139

tests/test_encode_data.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,3 +141,17 @@ class Foo(EIP712Struct):
141141
assert sign_bytes[0:2] == start_bytes
142142
assert sign_bytes[2:34] == exp_domain_bytes
143143
assert sign_bytes[34:] == exp_struct_bytes
144+
145+
146+
def test_none_replacement():
147+
class Foo(EIP712Struct):
148+
s = String()
149+
i = Int(256)
150+
151+
foo = Foo(**{})
152+
encoded_val = foo.encode_value()
153+
assert len(encoded_val) == 64
154+
155+
empty_string_hash = keccak(text='')
156+
assert encoded_val[0:32] == empty_string_hash
157+
assert encoded_val[32:] == bytes(32)

0 commit comments

Comments
 (0)