Skip to content

Commit 4ee6ab9

Browse files
committed
Fix basic type handling, add typing tests
1 parent 13ce5b2 commit 4ee6ab9

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

eip712_structs/types.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,25 @@ def encode_value(self, value) -> bytes:
1919
"""
2020
pass
2121

22+
def __eq__(self, other):
23+
self_type = getattr(self, 'type_name')
24+
other_type = getattr(other, 'type_name')
25+
26+
return self_type is not None and self_type == other_type
27+
28+
def __hash__(self):
29+
return hash(self.type_name)
30+
2231

2332
class Array(EIP712Type):
2433
def __init__(self, member_type: Union[EIP712Type, Type[EIP712Type]], fixed_length: int = 0):
34+
fixed_length = int(fixed_length)
2535
if fixed_length == 0:
2636
type_name = f'{member_type.type_name}[]'
2737
else:
2838
type_name = f'{member_type.type_name}[{fixed_length}]'
2939
self.member_type = member_type
40+
self.fixed_length = fixed_length
3041
super(Array, self).__init__(type_name)
3142

3243
def encode_value(self, value):
@@ -65,6 +76,7 @@ def encode_value(self, value):
6576

6677
class Bytes(EIP712Type):
6778
def __init__(self, length: int = 0):
79+
length = int(length)
6880
if length == 0:
6981
# Special case: Length of 0 means a dynamic bytes type
7082
type_name = 'bytes'
@@ -87,6 +99,7 @@ def encode_value(self, value):
8799

88100
class Int(EIP712Type):
89101
def __init__(self, length: int):
102+
length = int(length)
90103
if length < 8 or length > 256 or length % 8 != 0:
91104
raise ValueError(f'Int length must be a multiple of 8, between 8 and 256. Got: {length}')
92105
self.length = length
@@ -107,6 +120,7 @@ def encode_value(self, value):
107120

108121
class Uint(EIP712Type):
109122
def __init__(self, length: int):
123+
length = int(length)
110124
if length < 8 or length > 256 or length % 8 != 0:
111125
raise ValueError(f'Uint length must be a multiple of 8, between 8 and 256. Got: {length}')
112126
self.length = length
@@ -145,13 +159,13 @@ def from_solidity_type(solidity_type: str):
145159

146160
base_type = solidity_type_map[type_name]
147161
if opt_len:
148-
type_instance = base_type(opt_len)
162+
type_instance = base_type(int(opt_len))
149163
else:
150164
type_instance = base_type()
151165

152166
if is_array:
153167
if array_len:
154-
result = Array(type_instance, array_len)
168+
result = Array(type_instance, int(array_len))
155169
else:
156170
result = Array(type_instance)
157171
return result

tests/test_types.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
from eip712_structs import Address, Array, Boolean, Bytes, Int, String, Uint, EIP712Struct
4+
from eip712_structs.types import from_solidity_type
45

56

67
def test_bytes_validation():
@@ -56,3 +57,31 @@ class Foo(EIP712Struct):
5657

5758
assert Array(Foo).type_name == 'Foo[]'
5859
assert Array(Foo, 10).type_name == 'Foo[10]'
60+
61+
62+
def test_length_str_typing():
63+
# Ensure that if length is given as a string, it's typecast to int
64+
assert Array(String(), '5').fixed_length == 5
65+
assert Bytes('10').length == 10
66+
assert Int('128').length == 128
67+
assert Uint('128').length == 128
68+
69+
70+
def test_from_solidity_type():
71+
assert from_solidity_type('address') == Address()
72+
assert from_solidity_type('bool') == Boolean()
73+
assert from_solidity_type('bytes') == Bytes()
74+
assert from_solidity_type('bytes32') == Bytes(32)
75+
assert from_solidity_type('int128') == Int(128)
76+
assert from_solidity_type('string') == String()
77+
assert from_solidity_type('uint256') == Uint(256)
78+
79+
assert from_solidity_type('address[]') == Array(Address())
80+
assert from_solidity_type('address[10]') == Array(Address(), 10)
81+
assert from_solidity_type('bytes16[32]') == Array(Bytes(16), 32)
82+
83+
# Sanity check that equivalency is working as expected
84+
assert from_solidity_type('bytes32') != Bytes(31)
85+
assert from_solidity_type('bytes16[32]') != Array(Bytes(16), 31)
86+
assert from_solidity_type('bytes16[32]') != Array(Bytes(), 32)
87+
assert from_solidity_type('bytes16[32]') != Array(Bytes(8), 32)

0 commit comments

Comments
 (0)