Skip to content

Commit 79275b6

Browse files
committed
Add array type to parity test
1 parent f64f6bc commit 79275b6

File tree

2 files changed

+47
-9
lines changed

2 files changed

+47
-9
lines changed

tests/contracts/hash_test_contract.sol

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,21 +19,33 @@ contract TestContract {
1919
bytes30 bytes_30;
2020
bytes dyn_bytes;
2121
Bar bar;
22+
bytes1[] arr;
2223
}
2324

2425
string constant public BarSig = "Bar(uint256 bar_uint)";
25-
string constant public FooSig = "Foo(string s,uint256 u_i,int8 s_i,address a,bool b,bytes30 bytes_30,bytes dyn_bytes,Bar bar)Bar(uint256 bar_uint)";
26+
string constant public FooSig = "Foo(string s,uint256 u_i,int8 s_i,address a,bool b,bytes30 bytes_30,bytes dyn_bytes,Bar bar,bytes1[] arr)Bar(uint256 bar_uint)";
2627

2728
bytes32 constant public Bar_TYPEHASH = keccak256(
2829
abi.encodePacked("Bar(uint256 bar_uint)")
2930
);
3031
bytes32 constant public Foo_TYPEHASH = keccak256(
31-
abi.encodePacked("Foo(string s,uint256 u_i,int8 s_i,address a,bool b,bytes30 bytes_30,bytes dyn_bytes,Bar bar)Bar(uint256 bar_uint)")
32+
abi.encodePacked("Foo(string s,uint256 u_i,int8 s_i,address a,bool b,bytes30 bytes_30,bytes dyn_bytes,Bar bar,bytes1[] arr)Bar(uint256 bar_uint)")
3233
);
3334

3435
/******************/
3536
/* Hash Functions */
3637
/******************/
38+
function encodeBytes1Array(bytes1[] memory arr) public pure returns (bytes32) {
39+
uint256 len = arr.length;
40+
bytes32[] memory padded = new bytes32[](len);
41+
for (uint256 i = 0; i < len; i++) {
42+
padded[i] = bytes32(arr[i]);
43+
}
44+
return keccak256(
45+
abi.encodePacked(padded)
46+
);
47+
}
48+
3749
function hashBarStruct(Bar memory bar) public pure returns (bytes32) {
3850
return keccak256(abi.encode(
3951
Bar_TYPEHASH,
@@ -51,7 +63,8 @@ contract TestContract {
5163
foo.b,
5264
foo.bytes_30,
5365
keccak256(abi.encodePacked(foo.dyn_bytes)),
54-
hashBarStruct(foo.bar)
66+
hashBarStruct(foo.bar),
67+
encodeBytes1Array(foo.arr)
5568
));
5669
}
5770

@@ -71,7 +84,8 @@ contract TestContract {
7184
bool b,
7285
bytes30 bytes_30,
7386
bytes memory dyn_bytes,
74-
uint256 bar_uint
87+
uint256 bar_uint,
88+
bytes1[] memory arr
7589
) public pure returns (bytes32) {
7690
// Construct Foo struct with basic types
7791
Foo memory foo;
@@ -82,6 +96,7 @@ contract TestContract {
8296
foo.b = b;
8397
foo.bytes_30 = bytes_30;
8498
foo.dyn_bytes = dyn_bytes;
99+
foo.arr = arr;
85100

86101
// Construct Bar struct and add it to Foo
87102
Bar memory bar;

tests/test_chain_parity.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,24 @@
44
from requests.exceptions import ConnectionError
55
from web3 import HTTPProvider, Web3
66

7-
from eip712_structs import EIP712Struct, String, Uint, Int, Address, Boolean, Bytes
7+
from eip712_structs import EIP712Struct, String, Uint, Int, Address, Boolean, Bytes, Array
88

99

1010
@pytest.fixture(scope='module')
1111
def w3():
12+
"""Provide a Web3 client to interact with a local chain."""
1213
client = Web3(HTTPProvider('http://localhost:8545'))
1314
client.eth.defaultAccount = client.eth.accounts[0]
1415
return client
1516

1617

1718
@pytest.fixture(scope='module')
1819
def contract(w3):
20+
"""Deploys the test contract to the local chain, and returns a Web3.py Contract to interact with it.
21+
22+
Note this expects the contract to be compiled already.
23+
This project's docker-compose config pulls a solc container to do this for you.
24+
"""
1925
base_path = 'tests/contracts/build/TestContract'
2026
with open(f'{base_path}.abi', 'r') as f:
2127
abi = f.read()
@@ -31,6 +37,7 @@ def contract(w3):
3137

3238

3339
def skip_this_module():
40+
"""If we can't reach a local chain, then all tests in this module are skipped."""
3441
client = Web3(HTTPProvider('http://localhost:8545'))
3542
try:
3643
client.eth.accounts
@@ -39,6 +46,7 @@ def skip_this_module():
3946
return False
4047

4148

49+
# Implicitly adds this ``skipif`` mark to the tests below.
4250
pytestmark = pytest.mark.skipif(skip_this_module(), reason='No accessible test chain.')
4351

4452

@@ -47,6 +55,7 @@ class Bar(EIP712Struct):
4755
bar_uint = Uint(256)
4856

4957

58+
# TODO Add Array type (w/ appropriate test updates) to this struct.
5059
class Foo(EIP712Struct):
5160
s = String()
5261
u_i = Uint(256)
@@ -56,10 +65,12 @@ class Foo(EIP712Struct):
5665
bytes_30 = Bytes(30)
5766
dyn_bytes = Bytes()
5867
bar = Bar
68+
arr = Array(Bytes(1))
5969

6070

61-
def get_chain_hash(contract, s, u_i, s_i, a, b, bytes_30, dyn_bytes, bar_uint) -> bytes:
62-
result = contract.functions.hashFooStructFromParams(s, u_i, s_i, a, b, bytes_30, dyn_bytes, bar_uint).call()
71+
def get_chain_hash(contract, s, u_i, s_i, a, b, bytes_30, dyn_bytes, bar_uint, arr) -> bytes:
72+
"""Uses the contract to create and hash a Foo struct with the given parameters."""
73+
result = contract.functions.hashFooStructFromParams(s, u_i, s_i, a, b, bytes_30, dyn_bytes, bar_uint, arr).call()
6374
return result
6475

6576

@@ -81,23 +92,35 @@ def test_encoded_types(contract):
8192
remote_foo_hash = contract.functions.Foo_TYPEHASH().call()
8293
assert local_foo_hash == remote_foo_hash
8394

95+
array_type = Array(Bytes(1))
96+
bytes_array = [os.urandom(1) for _ in range(5)]
97+
local_encoded_array = array_type.encode_value(bytes_array)
98+
remote_encoded_array = contract.functions.encodeBytes1Array(bytes_array).call()
99+
assert local_encoded_array == remote_encoded_array
100+
84101

85102
def test_chain_hash_matches(contract):
103+
"""Assert that the hashes we derive locally match the hashes derived on-chain."""
104+
105+
# Initialize basic values
86106
s = 'some string'
87107
u_i = 1234
88108
s_i = -7
89109
a = Web3.toChecksumAddress(f'0x{os.urandom(20).hex()}')
90110
b = True
91111
bytes_30 = os.urandom(30)
92112
dyn_bytes = os.urandom(50)
113+
arr = [os.urandom(1) for _ in range(5)]
93114

115+
# Initialize a Bar struct, and check it standalone
94116
bar_uint = 1337
95117
bar_struct = Bar(bar_uint=bar_uint)
96118
local_bar_hash = bar_struct.hash_struct()
97119
remote_bar_hash = contract.functions.hashBarStructFromParams(bar_uint).call()
98120
assert local_bar_hash == remote_bar_hash
99121

100-
foo_struct = Foo(s=s, u_i=u_i, s_i=s_i, a=a, b=b, bytes_30=bytes_30, dyn_bytes=dyn_bytes, bar=bar_struct)
122+
# Initialize a Foo struct (including the Bar struct above) and check the hashes
123+
foo_struct = Foo(s=s, u_i=u_i, s_i=s_i, a=a, b=b, bytes_30=bytes_30, dyn_bytes=dyn_bytes, bar=bar_struct, arr=arr)
101124
local_foo_hash = foo_struct.hash_struct()
102-
remote_foo_hash = get_chain_hash(contract, s, u_i, s_i, a, b, bytes_30, dyn_bytes, bar_uint)
125+
remote_foo_hash = get_chain_hash(contract, s, u_i, s_i, a, b, bytes_30, dyn_bytes, bar_uint, arr)
103126
assert local_foo_hash == remote_foo_hash

0 commit comments

Comments
 (0)