diff --git a/scripts/decode.py b/scripts/decode.py new file mode 100644 index 0000000..4cca0c0 --- /dev/null +++ b/scripts/decode.py @@ -0,0 +1,669 @@ +""" +decode.py — Decodes OKX DEX Router calldata. +Integrates decode_functions.py and decode_fee.py into a single script. + +Usage: + python decode.py + +Requires: pip install eth-abi eth-utils "eth-hash[pycryptodome]" +""" + +import sys +import json +from eth_abi import decode as abi_decode +from eth_utils import keccak +from eth_utils import to_checksum_address as _eth_checksum + +# ============================================================================ +# Masks (core/masks.js) +# ============================================================================ + +ADDRESS_MASK = 0x000000000000000000000000ffffffffffffffffffffffffffffffffffffffff +ONE_FOR_ZERO_MASK = 0x8000000000000000000000000000000000000000000000000000000000000000 +WETH_UNWRAP_MASK = 0x2000000000000000000000000000000000000000000000000000000000000000 +ORDER_ID_MASK = 0x1fffffffffffffffffffffff0000000000000000000000000000000000000000 +WEIGHT_MASK = 0x00000000000000000000ffff0000000000000000000000000000000000000000 +REVERSE_MASK = 0x8000000000000000000000000000000000000000000000000000000000000000 +IS_TOKEN0_TAX_MASK = 0x1000000000000000000000000000000000000000000000000000000000000000 +IS_TOKEN1_TAX_MASK = 0x2000000000000000000000000000000000000000000000000000000000000000 +WETH_MASK = 0x4000000000000000000000000000000000000000000000000000000000000000 +NUMERATOR_MASK = 0x0000000000000000ffffffff0000000000000000000000000000000000000000 +SWAP_AMOUNT_MASK = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff +DAG_INPUT_INDEX_MASK = 0x0000000000000000ff0000000000000000000000000000000000000000000000 +DAG_OUTPUT_INDEX_MASK = 0x000000000000000000ff00000000000000000000000000000000000000000000 +MODE_NO_TRANSFER_MASK = 0x0800000000000000000000000000000000000000000000000000000000000000 +MODE_BY_INVEST_MASK = 0x0400000000000000000000000000000000000000000000000000000000000000 +MODE_PERMIT2_MASK = 0x0200000000000000000000000000000000000000000000000000000000000000 + +# ============================================================================ +# ABI Definitions (core/abi.js) +# ============================================================================ + +_BR = [ # BaseRequest components + {'type': 'uint256', 'name': 'fromToken'}, + {'type': 'address', 'name': 'toToken'}, + {'type': 'uint256', 'name': 'fromTokenAmount'}, + {'type': 'uint256', 'name': 'minReturnAmount'}, + {'type': 'uint256', 'name': 'deadLine'}, +] +_RP = [ # RouterPath components + {'type': 'address[]', 'name': 'mixAdapters'}, + {'type': 'address[]', 'name': 'assetTo'}, + {'type': 'uint256[]', 'name': 'rawData'}, + {'type': 'bytes[]', 'name': 'extraData'}, + {'type': 'uint256', 'name': 'fromToken'}, +] +_ST = [ # Settler (extraData) components + {'type': 'uint256', 'name': 'fromToken'}, + {'type': 'address', 'name': 'toToken'}, + {'type': 'address', 'name': 'receiver'}, + {'type': 'address', 'name': 'payer'}, + {'type': 'uint256', 'name': 'fromTokenAmount'}, + {'type': 'uint256', 'name': 'minReturnAmount'}, + {'type': 'uint256', 'name': 'deadLine'}, + {'type': 'uint256', 'name': 'orderId'}, + {'type': 'bool', 'name': 'isToB'}, + {'type': 'bytes', 'name': 'settlerData'}, +] + +def _p(name, typ, components=None): + d = {'name': name, 'type': typ} + if components is not None: + d['components'] = components + return d + +_ABI = [ + {'name': 'smartSwapByOrderId', 'inputs': [ + _p('orderId', 'uint256'), + _p('baseRequest', 'tuple', _BR), + _p('batchesAmount', 'uint256[]'), + _p('batches', 'tuple[][]', _RP), + _p('extraData', 'tuple[]', _ST), + ]}, + {'name': 'unxswapByOrderId', 'inputs': [ + _p('srcToken', 'uint256'), + _p('amount', 'uint256'), + _p('minReturn', 'uint256'), + _p('pools', 'bytes32[]'), + ]}, + {'name': 'smartSwapByInvest', 'inputs': [ + _p('baseRequest', 'tuple', _BR), + _p('batchesAmount', 'uint256[]'), + _p('batches', 'tuple[][]', _RP), + _p('extraData', 'tuple[]', _ST), + _p('to', 'address'), + ]}, + {'name': 'smartSwapByInvestWithRefund', 'inputs': [ + _p('baseRequest', 'tuple', _BR), + _p('batchesAmount', 'uint256[]'), + _p('batches', 'tuple[][]', _RP), + _p('extraData', 'tuple[]', _ST), + _p('to', 'address'), + _p('refundTo', 'address'), + ]}, + {'name': 'uniswapV3SwapTo', 'inputs': [ + _p('receiver', 'uint256'), + _p('amount', 'uint256'), + _p('minReturn', 'uint256'), + _p('pools', 'uint256[]'), + ]}, + {'name': 'smartSwapTo', 'inputs': [ + _p('orderId', 'uint256'), + _p('receiver', 'address'), + _p('baseRequest', 'tuple', _BR), + _p('batchesAmount', 'uint256[]'), + _p('batches', 'tuple[][]', _RP), + _p('extraData', 'tuple[]', _ST), + ]}, + {'name': 'unxswapTo', 'inputs': [ + _p('srcToken', 'uint256'), + _p('amount', 'uint256'), + _p('minReturn', 'uint256'), + _p('receiver', 'address'), + _p('pools', 'bytes32[]'), + ]}, + {'name': 'uniswapV3SwapToWithBaseRequest', 'inputs': [ + _p('orderId', 'uint256'), + _p('receiver', 'address'), + _p('baseRequest', 'tuple', _BR), + _p('pools', 'uint256[]'), + ]}, + {'name': 'unxswapToWithBaseRequest', 'inputs': [ + _p('orderId', 'uint256'), + _p('receiver', 'address'), + _p('baseRequest', 'tuple', _BR), + _p('pools', 'bytes32[]'), + ]}, + {'name': 'swapWrap', 'inputs': [ + _p('orderId', 'uint256'), + _p('rawdata', 'uint256'), + ]}, + {'name': 'swapWrapToWithBaseRequest', 'inputs': [ + _p('orderId', 'uint256'), + _p('receiver', 'address'), + _p('baseRequest', 'tuple', _BR), + ]}, + {'name': 'dagSwapByOrderId', 'inputs': [ + _p('orderId', 'uint256'), + _p('baseRequest', 'tuple', _BR), + _p('paths', 'tuple[]', _RP), + ]}, + {'name': 'dagSwapTo', 'inputs': [ + _p('orderId', 'uint256'), + _p('receiver', 'address'), + _p('baseRequest', 'tuple', _BR), + _p('paths', 'tuple[]', _RP), + ]}, + {'name': 'approve', 'inputs': [ + _p('spender', 'address'), + _p('amount', 'uint256'), + ]}, +] + +# ============================================================================ +# Selector computation +# ============================================================================ + +def _canonical_type_str(typ: str, components=None) -> str: + if 'tuple' in typ: + suffix = typ[len('tuple'):] + inner = ','.join(_canonical_type_str(c['type'], c.get('components')) for c in components) + return f'({inner}){suffix}' + return typ + +def _compute_selector(func_def: dict) -> str: + parts = [_canonical_type_str(inp['type'], inp.get('components')) for inp in func_def['inputs']] + sig = f"{func_def['name']}({','.join(parts)})" + return '0x' + keccak(text=sig).hex()[:8] + +_SELECTOR_MAP: dict = {} +for _fn in _ABI: + _SELECTOR_MAP[_compute_selector(_fn)] = _fn + +# ============================================================================ +# Low-level helpers (formatters/formatters.js) +# ============================================================================ + +def to_checksum_address(raw_hex: str) -> str: + try: + return _eth_checksum(raw_hex) + except Exception: + return raw_hex + +def _to_int(value) -> int: + if isinstance(value, int): + return value + if isinstance(value, bytes): + return int.from_bytes(value, 'big') + if isinstance(value, str): + return int(value, 16) if value.startswith('0x') else int(value) + return int(value) + +def get_value(value): + if isinstance(value, bool): + return value + if isinstance(value, int): + return str(value) + if isinstance(value, bytes): + return '0x' + value.hex() + if isinstance(value, str): + if value.startswith('0x') and len(value) == 42: + return to_checksum_address(value) + return value + if isinstance(value, (list, tuple)): + return [get_value(item) for item in value] + return value + +def bytes32_to_address(param) -> str: + if param is None: + return '0x' + '0' * 40 + try: + addr = _to_int(param) & ADDRESS_MASK + return to_checksum_address('0x' + format(addr, '040x')) + except Exception: + return '0x' + '0' * 40 + +# ============================================================================ +# Type checkers (core/type_checkers.js) +# ============================================================================ + +def is_base_request_tuple(inp: dict, value) -> bool: + if inp.get('type') != 'tuple': + return False + comps = inp.get('components', []) + return (len(comps) == 5 and + [c['type'] for c in comps] == ['uint256', 'address', 'uint256', 'uint256', 'uint256'] and + isinstance(value, (list, tuple)) and len(value) == 5) + +def is_router_path_array(inp: dict, value) -> bool: + if inp.get('type') not in ('tuple[][]', 'tuple[]'): + return False + comps = inp.get('components', []) + return (len(comps) == 5 and + [c['type'] for c in comps] == ['address[]', 'address[]', 'uint256[]', 'bytes[]', 'uint256'] and + isinstance(value, (list, tuple))) + +def is_router_path_tuple(item) -> bool: + return isinstance(item, (list, tuple)) and len(item) == 5 + +def is_packed_receiver(inp: dict, param_name: str) -> bool: + return inp.get('type') == 'uint256' and param_name == 'receiver' + +def is_pools_array(inp: dict, param_name: str) -> bool: + return inp.get('type') in ('uint256[]', 'bytes32[]') and param_name == 'pools' + +def is_packed_src_token(inp: dict, param_name: str) -> bool: + return inp.get('type') == 'uint256' and param_name == 'srcToken' + +def is_swap_wrap_rawdata(inp: dict, param_name: str) -> bool: + return inp.get('type') == 'uint256' and param_name == 'rawdata' + +def is_from_token_with_mode(inp: dict, param_name: str, function_name: str) -> bool: + if inp.get('type') != 'uint256' or param_name != 'fromToken': + return False + return bool(function_name) and ( + function_name.startswith('dagSwap') or + function_name.startswith('smartSwap') or + function_name in ('smartSwapByInvest', 'smartSwapByInvestWithRefund') + ) + +# ============================================================================ +# Formatters (formatters/formatters.js) +# ============================================================================ + +def format_base_request(arr, function_name=None) -> dict: + from_token, to_token, from_token_amount, min_return_amount, dead_line = arr + return { + 'fromToken': bytes32_to_address(from_token), + 'toToken': get_value(to_token), + 'fromTokenAmount': get_value(from_token_amount), + 'minReturnAmount': get_value(min_return_amount), + 'deadLine': get_value(dead_line), + } + +def format_router_path_array(arr, function_name=None): + is_dag = function_name and function_name.startswith('dagSwap') + if is_dag: + return [format_router_path(rp, function_name) if is_router_path_tuple(rp) else get_value(rp) for rp in arr] + return [[format_router_path(rp, function_name) if is_router_path_tuple(rp) else get_value(rp) for rp in batch] for batch in arr] + +def format_router_path(arr, function_name=None) -> dict: + mix_adapters, asset_to, raw_data, extra_data, from_token = arr + supports_mode = function_name and ( + function_name.startswith('dagSwap') or + function_name.startswith('smartSwap') or + function_name in ('smartSwapByInvest', 'smartSwapByInvestWithRefund') + ) + return { + 'mixAdapters': get_value(mix_adapters), + 'assetTo': get_value(asset_to), + 'rawData': _decode_raw_data_array(raw_data, function_name), + 'extraData': get_value(extra_data), + 'fromToken': unpack_from_token_with_mode(from_token) if supports_mode else get_value(from_token), + } + +def _decode_raw_data_array(arr, function_name=None): + if not isinstance(arr, (list, tuple)): + return get_value(arr) + is_dag = function_name and function_name.startswith('dagSwap') + return [unpack_dag_raw_data(item) if is_dag else unpack_raw_data(item) for item in arr] + +def unpack_raw_data(v) -> dict: + try: + n = _to_int(v) + return { + 'poolAddress': to_checksum_address('0x' + format(n & ADDRESS_MASK, '040x')), + 'reverse': bool(n & REVERSE_MASK), + 'weight': str((n & WEIGHT_MASK) >> 160), + } + except Exception as e: + return {'original': get_value(v), 'error': f'Failed to unpack rawData: {e}'} + +def unpack_dag_raw_data(v) -> dict: + try: + n = _to_int(v) + return { + 'poolAddress': to_checksum_address('0x' + format(n & ADDRESS_MASK, '040x')), + 'reverse': bool(n & REVERSE_MASK), + 'weight': str((n & WEIGHT_MASK) >> 160), + 'inputIndex': str((n & DAG_INPUT_INDEX_MASK) >> 184), + 'outputIndex': str((n & DAG_OUTPUT_INDEX_MASK) >> 176), + } + except Exception as e: + return {'original': get_value(v), 'error': f'Failed to unpack DAG rawData: {e}'} + +def unpack_receiver(v) -> dict: + try: + n = _to_int(v) + return { + 'orderId': str((n & ORDER_ID_MASK) >> 160), + 'address': to_checksum_address('0x' + format(n & ADDRESS_MASK, '040x')), + } + except Exception as e: + return {'original': get_value(v), 'error': f'Failed to unpack receiver: {e}'} + +def unpack_pools_array(arr, function_name: str) -> list: + if not isinstance(arr, (list, tuple)) or len(arr) == 0: + return get_value(arr) + is_unxswap = function_name and function_name.startswith('unxswap') + return [unpack_unxswap_pool(pool) if is_unxswap else unpack_uniswap_v3_pool(pool) for pool in arr] + +def unpack_unxswap_pool(v) -> dict: + try: + n = _to_int(v) + return { + 'isToken0Tax': bool(n & IS_TOKEN0_TAX_MASK), + 'isToken1Tax': bool(n & IS_TOKEN1_TAX_MASK), + 'WETH': bool(n & WETH_MASK), + 'isOneForZero': bool(n & ONE_FOR_ZERO_MASK), + 'numerator': str((n & NUMERATOR_MASK) >> 160), + 'address': to_checksum_address('0x' + format(n & ADDRESS_MASK, '040x')), + } + except Exception as e: + return {'original': get_value(v), 'error': f'Failed to unpack unxswap pool: {e}'} + +def unpack_uniswap_v3_pool(v) -> dict: + try: + n = _to_int(v) + return { + 'isOneForZero': bool(n & ONE_FOR_ZERO_MASK), + 'wethUnwrap': bool(n & WETH_UNWRAP_MASK), + 'pool': to_checksum_address('0x' + format(n & ADDRESS_MASK, '040x')), + } + except Exception as e: + return {'original': get_value(v), 'error': f'Failed to unpack uniswapV3 pool: {e}'} + +def unpack_src_token(v) -> dict: + try: + n = _to_int(v) + return { + 'orderId': str(n >> 160), + 'address': to_checksum_address('0x' + format(n & ADDRESS_MASK, '040x')), + } + except Exception as e: + return {'original': get_value(v), 'error': f'Failed to unpack srcToken: {e}'} + +def unpack_swap_rawdata(v) -> dict: + try: + n = _to_int(v) + return { + 'reversed': bool(n & REVERSE_MASK), + 'amount': str(n & SWAP_AMOUNT_MASK), + } + except Exception as e: + return {'original': get_value(v), 'error': f'Failed to unpack swapWrap rawdata: {e}'} + +def unpack_from_token_with_mode(v) -> dict: + try: + n = _to_int(v) + if n & MODE_NO_TRANSFER_MASK: + flag = 'NO_TRANSFER' + elif n & MODE_BY_INVEST_MASK: + flag = 'BY_INVEST' + elif n & MODE_PERMIT2_MASK: + flag = 'PERMIT2' + else: + flag = 'DEFAULT' + return { + 'address': to_checksum_address('0x' + format(n & ADDRESS_MASK, '040x')), + 'flag': flag, + } + except Exception as e: + return {'original': get_value(v), 'error': f'Failed to unpack fromToken with mode: {e}'} + +# ============================================================================ +# decode_functions logic (decode_functions.js) +# ============================================================================ + +def _create_named_parameters(inputs: list, decoded_params: tuple, function_name: str) -> dict: + named = {} + for i, inp in enumerate(inputs): + param_name = inp.get('name') or f'param{i}' + value = get_value(decoded_params[i]) + + if is_base_request_tuple(inp, value): + value = format_base_request(value, function_name) + elif is_router_path_array(inp, value): + value = format_router_path_array(value, function_name) + elif is_packed_receiver(inp, param_name): + value = unpack_receiver(value) + elif is_pools_array(inp, param_name): + value = unpack_pools_array(value, function_name) + elif is_packed_src_token(inp, param_name): + value = unpack_src_token(value) + elif is_swap_wrap_rawdata(inp, param_name): + value = unpack_swap_rawdata(value) + elif is_from_token_with_mode(inp, param_name, function_name): + value = unpack_from_token_with_mode(value) + + named[param_name] = value + return named + +def decode_functions(calldata: str) -> dict: + try: + if not calldata or not isinstance(calldata, str): + return {'error': 'Invalid calldata input'} + if not calldata.startswith('0x'): + calldata = '0x' + calldata + if len(calldata) < 10: + return {'error': 'calldata length is too short'} + + selector = calldata[:10].lower() + func_def = _SELECTOR_MAP.get(selector) + if not func_def: + return {'error': f'Unknown function selector: {selector}', 'selector': selector} + + eth_types = [_canonical_type_str(inp['type'], inp.get('components')) for inp in func_def['inputs']] + decoded = abi_decode(eth_types, bytes.fromhex(calldata[2:])[4:]) + named = _create_named_parameters(func_def['inputs'], decoded, func_def['name']) + + return {'function': {'name': func_def['name'], 'selector': selector}, **named} + + except Exception as e: + return {'error': f'Decoding failed: {e}', 'originalError': str(e)} + +# ============================================================================ +# Commission constants + helpers (decode_fee.py / decode_commission.py) +# ============================================================================ + +_CBYTE = {'FLAG': 12, 'RATE': 12, 'ADDRESS': 40, 'BLOCK': 64} +_FLAG_PREFIXES = {'SINGLE': '0x3ca2', 'DUAL': '0x2222', 'MULTIPLE': '0x8888'} +_VALID_FLAGS = [ + '0x3ca20afc2aaa', '0x3ca20afc2bbb', + '0x22220afc2aaa', '0x22220afc2bbb', + '0x88880afc2aaa', '0x88880afc2bbb', +] +_MIN_REFERRERS = 3 +_MAX_REFERRERS = 8 +_ORDINALS = ['first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth'] + +def _commission_type(flag: str) -> str: + f = flag.lower() + for kind, prefix in _FLAG_PREFIXES.items(): + if f.startswith(prefix): + amount = kind + break + else: + amount = 'UNKNOWN' + token = 'FROM_TOKEN_COMMISSION' if f.endswith('aaa') else 'TO_TOKEN_COMMISSION' + return f'{amount}_{token}' + +def _parse_commission(hex32: str) -> dict: + h = hex32.lower().removeprefix('0x') + flag = '0x' + h[:_CBYTE['FLAG']] + if not any(f.lower() == flag for f in _VALID_FLAGS): + raise ValueError(f'Invalid commission flag: {flag}') + rate = int(h[_CBYTE['FLAG']:_CBYTE['FLAG'] + _CBYTE['RATE']], 16) + a = _CBYTE['FLAG'] + _CBYTE['RATE'] + return {'flag': flag, 'commissionType': _commission_type(flag), 'rate': str(rate), 'address': '0x' + h[a:a + _CBYTE['ADDRESS']]} + +def _parse_middle(hex32: str) -> dict: + h = hex32.lower().removeprefix('0x') + return {'isToB': h[:2] == '80', 'token': '0x' + h[24:]} + +def _parse_referrer_num(hex32: str) -> int: + return int(hex32.lower().removeprefix('0x')[2:4], 16) + +def _extract_blocks(calldata_hex: str, flag_hex: str, count: int): + idx = calldata_hex.find(flag_hex) + if idx == -1 or len(calldata_hex) < idx + _CBYTE['BLOCK'] * count: + return None + return {'flagStart': idx, 'blocks': ['0x' + calldata_hex[idx + i * _CBYTE['BLOCK']:idx + (i + 1) * _CBYTE['BLOCK']] for i in range(count)]} + +# ============================================================================ +# Trim constants + helpers (decode_fee.py / decode_trim.py) +# ============================================================================ + +_TRIM_FLAGS = {'SINGLE': '777777771111', 'DUAL': '777777772222'} +_BLOCK = 64 + +def _parse_trim_data(hex32: str) -> dict: + h = hex32.lower().removeprefix('0x') + flag = '0x' + h[:12] + valid = ['0x' + v for v in _TRIM_FLAGS.values()] + if flag not in valid: + raise ValueError(f'Invalid trim flag: {flag}') + return {'flag': flag, 'rate': str(int(h[12:24], 16)), 'address': '0x' + h[24:64]} + +def _parse_expect_amount(hex32: str) -> dict: + h = hex32.lower().removeprefix('0x') + flag = '0x' + h[:12] + valid = ['0x' + v for v in _TRIM_FLAGS.values()] + if flag not in valid: + raise ValueError(f'Invalid trim flag in expect amount block: {flag}') + return {'expectAmount': str(int(h[24:64], 16)), 'trimType': 'toB' if h[12:14] == '80' else 'toC'} + +# ============================================================================ +# Fee extraction (decode_fee.py) +# ============================================================================ + +def extract_commission_info(calldata_hex: str) -> dict: + c = calldata_hex.lower().removeprefix('0x') + + for flag in ['0x3ca20afc2aaa', '0x3ca20afc2bbb']: + idx = c.find(flag[2:]) + if idx != -1 and idx >= _CBYTE['BLOCK']: + try: + return { + 'hasCommission': True, 'referCount': 1, + 'middle': _parse_middle('0x' + c[idx - _CBYTE['BLOCK']:idx]), + 'first': _parse_commission('0x' + c[idx:idx + _CBYTE['BLOCK']]), + } + except Exception: + pass + + for flag in ['0x22220afc2aaa', '0x22220afc2bbb']: + r = _extract_blocks(c, flag[2:], 3) + if r: + try: + first, middle, last = r['blocks'] + return {'hasCommission': True, 'referCount': 2, + 'first': _parse_commission(first), 'middle': _parse_middle(middle), 'last': _parse_commission(last)} + except Exception: + pass + + for flag in ['0x88880afc2aaa', '0x88880afc2bbb']: + fh = flag[2:] + if c.find(fh) == -1 or len(c) < _CBYTE['BLOCK'] * 4: + continue + try: + ms = len(c) - _CBYTE['BLOCK'] * 2 + ref_num = _parse_referrer_num('0x' + c[ms:ms + _CBYTE['BLOCK']]) + if not (_MIN_REFERRERS <= ref_num <= _MAX_REFERRERS): + continue + total = ref_num + 1 + if len(c) < _CBYTE['BLOCK'] * total: + continue + r = _extract_blocks(c, fh, total) + if not r: + continue + mid_i, c1_i = ref_num - 1, ref_num + ret = {'hasCommission': True, 'referCount': ref_num, + _ORDINALS[0]: _parse_commission(r['blocks'][0]), + 'middle': _parse_middle(r['blocks'][mid_i])} + for i in range(1, mid_i): + ret[_ORDINALS[i]] = _parse_commission(r['blocks'][i]) + ret[_ORDINALS[mid_i]] = _parse_commission(r['blocks'][c1_i]) + return ret + except Exception: + pass + + return {'hasCommission': False} + +def extract_trim_info(calldata_hex: str) -> dict: + c = calldata_hex.lower().removeprefix('0x') + + single_flag = _TRIM_FLAGS['SINGLE'] + idx = c.find(single_flag) + if idx != -1: + last_idx, search = idx, idx + 1 + while True: + nxt = c.find(single_flag, search) + if nxt == -1: + break + last_idx, search = nxt, nxt + 1 + if last_idx >= _BLOCK: + try: + td = _parse_trim_data('0x' + c[last_idx:last_idx + _BLOCK]) + ea = _parse_expect_amount('0x' + c[last_idx - _BLOCK:last_idx]) + return { + 'hasTrim': ea['trimType'], 'trimRate': td['rate'], 'trimAddress': td['address'], + 'expectAmountOut': ea['expectAmount'], 'chargeRate': '0', + 'chargeAddress': '0x0000000000000000000000000000000000000000', + } + except Exception: + pass + + dual_flag = _TRIM_FLAGS['DUAL'] + positions, search = [], 0 + while True: + pos = c.find(dual_flag, search) + if pos == -1: + break + positions.append(pos) + search = pos + 1 + if len(positions) >= 3: + fs = positions[-1] + if fs >= _BLOCK * 2: + try: + td1 = _parse_trim_data('0x' + c[fs:fs + _BLOCK]) + ea = _parse_expect_amount('0x' + c[fs - _BLOCK:fs]) + td2 = _parse_trim_data('0x' + c[fs - _BLOCK * 2:fs - _BLOCK]) + return { + 'hasTrim': ea['trimType'], 'trimRate': td1['rate'], 'trimAddress': td1['address'], + 'expectAmountOut': ea['expectAmount'], 'chargeRate': td2['rate'], 'chargeAddress': td2['address'], + } + except Exception: + pass + + return {'hasTrim': False} + +# ============================================================================ +# Unified decode entry point +# ============================================================================ + +def decode(calldata: str) -> dict: + """ + Decode DEX Router calldata: function parameters + commission + trim fee. + + Returns a dict with: + - 'function': { name, selector } + - all named parameters (flattened at top level) + - commission fields flattened at top level + - 'hasTrim' and trim fields flattened at top level + """ + result = decode_functions(calldata) + result.update(extract_commission_info(calldata)) + result.update(extract_trim_info(calldata)) + return result + +# ============================================================================ +# CLI +# ============================================================================ + +if __name__ == '__main__': + if len(sys.argv) < 2: + print('Usage: python decode.py ', file=sys.stderr) + sys.exit(1) + + print(json.dumps(decode(sys.argv[1]), indent=2, default=str)) diff --git a/scripts/encode.py b/scripts/encode.py new file mode 100644 index 0000000..c485be2 --- /dev/null +++ b/scripts/encode.py @@ -0,0 +1,647 @@ +""" +encode.py — Standalone OKX DEX Router encoder. + +Integrates encode_functions.py and encode_fee.py into a single file. +Encodes function calldata, then optionally appends commission and/or trim fee data. + +Requires: pip install eth-abi "eth-hash[pycryptodome]" + +Public API: + encode_functions(json_data) → calldata hex + add_commission_to_calldata(calldata, commission_data) → calldata hex + add_trim_to_calldata(calldata, trim_data) → calldata hex + add_fee_to_calldata(calldata, commission_data, trim_data) → calldata hex + encode(json_data, commission_data, trim_data) → calldata hex (all-in-one) + +CLI: + python encode.py [commission_json] [trim_json] +""" + +try: + from eth_abi import encode as _abi_encode + from eth_hash.auto import keccak as _keccak +except ImportError: + raise ImportError( + 'Missing dependencies. Run: pip install eth-abi "eth-hash[pycryptodome]"' + ) + +# ============================================================================ +# Masks (from core/masks.js) +# ============================================================================ + +_ONE_FOR_ZERO_MASK = 0x8000000000000000000000000000000000000000000000000000000000000000 +_WETH_UNWRAP_MASK = 0x2000000000000000000000000000000000000000000000000000000000000000 +_REVERSE_MASK = 0x8000000000000000000000000000000000000000000000000000000000000000 +_IS_TOKEN0_TAX_MASK = 0x1000000000000000000000000000000000000000000000000000000000000000 +_IS_TOKEN1_TAX_MASK = 0x2000000000000000000000000000000000000000000000000000000000000000 +_WETH_MASK = 0x4000000000000000000000000000000000000000000000000000000000000000 +_SWAP_AMOUNT_MASK = 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff + +_MODE_NO_TRANSFER = 1 << 251 +_MODE_BY_INVEST = 1 << 250 +_MODE_PERMIT2 = 1 << 249 + +# ============================================================================ +# Commission constants (from encode_commission.js) +# ============================================================================ + +_COMMISSION_BYTE_SIZE = {'FLAG': 12, 'RATE': 12, 'ADDRESS': 40, 'BLOCK': 64} +_PADDING = '00' * 10 +_ORDINAL_NAMES = ['first', 'second', 'third', 'fourth', 'fifth', 'sixth', 'seventh', 'eighth'] +_MIN_COMMISSION_COUNT = 1 +_MAX_COMMISSION_COUNT = 8 + +# ============================================================================ +# Trim constants (from encode_trim.js) +# ============================================================================ + +_TRIM_FLAGS = {'SINGLE': '777777771111', 'DUAL': '777777772222'} +_IS_TOB_TRIM = {'TOB': '80', 'TOC': '00'} + + +# ============================================================================ +# Shared utilities +# ============================================================================ + +def _to_int(x) -> int: + if isinstance(x, int): + return x + s = str(x) + return int(s, 16) if s.startswith(('0x', '0X')) else int(s) + + +def _to_bytes(x, size: int = None) -> bytes: + if isinstance(x, bytes): + b = x + elif isinstance(x, str): + h = x.removeprefix('0x') + if size: + h = h.zfill(size * 2) + b = bytes.fromhex(h) + else: + b = bytes(x) + if size and len(b) < size: + b = b.rjust(size, b'\x00') + return b + + +def _normalize_hex(hex_str: str, length: int) -> str: + return hex_str.removeprefix('0x').lower().zfill(length) + + +# ============================================================================ +# Packers (from encode_packers.js) +# ============================================================================ + +def pack_src_token(src_token) -> int: + if isinstance(src_token, (str, int)): + return _to_int(src_token) + return (_to_int(src_token['orderId']) << 160) + _to_int(src_token['address']) + + +def pack_receiver(receiver) -> int: + if isinstance(receiver, (str, int)): + return _to_int(receiver) + packed = _to_int(receiver['address']) + order_id = receiver.get('orderId') + if order_id is not None: + packed |= _to_int(order_id) << 160 + return packed + + +def pack_rawdata(rawdata) -> int: + if isinstance(rawdata, (str, int)): + return _to_int(rawdata) + packed = _to_int(rawdata['amount']) & _SWAP_AMOUNT_MASK + if rawdata.get('reversed'): + packed |= _REVERSE_MASK + return packed + + +def pack_uniswap_v3_pool(pool) -> int: + if isinstance(pool, (str, int)): + return _to_int(pool) + packed = _to_int(pool['pool']) + if pool.get('isOneForZero'): + packed |= _ONE_FOR_ZERO_MASK + if pool.get('wethUnwrap'): + packed |= _WETH_UNWRAP_MASK + return packed + + +def pack_unxswap_pool(pool) -> bytes: + if isinstance(pool, (str, bytes)): + return _to_bytes(pool, 32) + packed = _to_int(pool['address']) + if pool.get('isToken0Tax'): + packed |= _IS_TOKEN0_TAX_MASK + if pool.get('isToken1Tax'): + packed |= _IS_TOKEN1_TAX_MASK + if pool.get('WETH'): + packed |= _WETH_MASK + if pool.get('isOneForZero'): + packed |= _ONE_FOR_ZERO_MASK + packed |= (_to_int(pool.get('numerator', 0)) & 0xFFFFFFFF) << 160 + return _to_bytes(hex(packed), 32) + + +def pack_dag_raw_data(raw_data) -> int: + if isinstance(raw_data, (str, int)): + return _to_int(raw_data) + packed = _to_int(raw_data['poolAddress']) + if raw_data.get('weight') is not None: + packed |= (_to_int(raw_data['weight']) & 0xFFFF) << 160 + if raw_data.get('outputIndex') is not None: + packed |= (_to_int(raw_data['outputIndex']) & 0xFF) << 176 + if raw_data.get('inputIndex') is not None: + packed |= (_to_int(raw_data['inputIndex']) & 0xFF) << 184 + if raw_data.get('reverse'): + packed |= _REVERSE_MASK + return packed + + +def _pack_raw_data_array(raw_data_array: list) -> list: + if not isinstance(raw_data_array, list): + return raw_data_array + result = [] + for rd in raw_data_array: + if isinstance(rd, (str, int)): + result.append(_to_int(rd)) + else: + packed = _to_int(rd['poolAddress']) + if rd.get('reverse'): + packed |= _REVERSE_MASK + if rd.get('weight') is not None: + packed |= (_to_int(rd['weight']) & 0xFFFF) << 160 + result.append(packed) + return result + + +def _pack_dag_raw_data_array(raw_data_array: list) -> list: + if not isinstance(raw_data_array, list): + return raw_data_array + return [pack_dag_raw_data(rd) for rd in raw_data_array] + + +# ============================================================================ +# Helpers (from encode_helpers.js) +# ============================================================================ + +def _get_mode_by_name(flag_name: str) -> int: + return {'NO_TRANSFER': _MODE_NO_TRANSFER, 'BY_INVEST': _MODE_BY_INVEST, 'PERMIT2': _MODE_PERMIT2}.get(flag_name, 0) + + +def _process_from_token_with_mode(from_token) -> int: + if isinstance(from_token, dict): + address = from_token.get('address', 0) + flag = from_token.get('flag', 0) + if isinstance(flag, str): + flag = _get_mode_by_name(flag) + return _to_int(str(address)) | int(flag) + return _to_int(str(from_token)) + + +def _prepare_base_request_tuple(base_request: dict, function_name: str = None, order_id=None) -> tuple: + if not base_request: + raise ValueError('Missing baseRequest parameter') + from_token = base_request['fromToken'] + if function_name == 'unxswapToWithBaseRequest' and order_id: + from_token = (_to_int(str(order_id)) << 160) | _to_int(str(from_token)) + else: + from_token = _to_int(str(from_token)) + return ( + from_token, + base_request['toToken'], + _to_int(str(base_request['fromTokenAmount'])), + _to_int(str(base_request['minReturnAmount'])), + _to_int(str(base_request['deadLine'])), + ) + + +def _prepare_batches_tuples(batches: list) -> list: + if not isinstance(batches, list): + raise ValueError('Batches must be an array') + return [ + [ + ( + list(rp['mixAdapters']), + list(rp['assetTo']), + _pack_raw_data_array(rp['rawData']), + [_to_bytes(ed) for ed in rp.get('extraData', [])], + _process_from_token_with_mode(rp['fromToken']), + ) + for rp in batch + ] + for batch in batches + ] + + +def _prepare_dag_paths_tuples(paths: list) -> list: + if not isinstance(paths, list): + raise ValueError('DAG paths must be an array') + return [ + ( + list(rp['mixAdapters']), + list(rp['assetTo']), + _pack_dag_raw_data_array(rp['rawData']), + [_to_bytes(ed) for ed in rp.get('extraData', [])], + _process_from_token_with_mode(rp['fromToken']), + ) + for rp in paths + ] + + +# ============================================================================ +# Parameter preparation (from encode_parameters.js) +# ============================================================================ + +def _prep_extra_data(extra_data_list: list) -> list: + return [ + ( + _to_int(str(ed['fromToken'])), + ed['toToken'], + ed['receiver'], + ed['payer'], + _to_int(str(ed['fromTokenAmount'])), + _to_int(str(ed['minReturnAmount'])), + _to_int(str(ed['deadLine'])), + _to_int(str(ed['orderId'])), + bool(ed['isToB']), + _to_bytes(ed.get('settlerData', '0x')), + ) + for ed in extra_data_list + ] + + +def _prepare_smart_swap_by_order_id_params(d): + if not all([d.get('orderId'), d.get('baseRequest'), d.get('batchesAmount'), d.get('batches')]): + raise ValueError('Missing required parameters for smartSwapByOrderId') + return [_to_int(str(d['orderId'])), _prepare_base_request_tuple(d['baseRequest']), [_to_int(str(x)) for x in d['batchesAmount']], _prepare_batches_tuples(d['batches']), _prep_extra_data(d.get('extraData', []))] + + +def _prepare_smart_swap_by_invest_params(d): + if not all([d.get('baseRequest'), d.get('batchesAmount'), d.get('batches'), d.get('to')]): + raise ValueError('Missing required parameters for smartSwapByInvest') + return [_prepare_base_request_tuple(d['baseRequest']), [_to_int(str(x)) for x in d['batchesAmount']], _prepare_batches_tuples(d['batches']), _prep_extra_data(d.get('extraData', [])), d['to']] + + +def _prepare_smart_swap_by_invest_with_refund_params(d): + if not all([d.get('baseRequest'), d.get('batchesAmount'), d.get('batches'), d.get('to'), d.get('refundTo')]): + raise ValueError('Missing required parameters for smartSwapByInvestWithRefund') + return [_prepare_base_request_tuple(d['baseRequest']), [_to_int(str(x)) for x in d['batchesAmount']], _prepare_batches_tuples(d['batches']), _prep_extra_data(d.get('extraData', [])), d['to'], d['refundTo']] + + +def _prepare_uniswap_v3_swap_to_params(d): + if not all([d.get('receiver'), d.get('amount'), d.get('minReturn'), d.get('pools')]): + raise ValueError('Missing required parameters for uniswapV3SwapTo') + receiver = d['receiver'] + if isinstance(receiver, str): + receiver_obj = {'orderId': d.get('orderId', '0'), 'address': receiver} + else: + receiver_obj = {'orderId': d.get('orderId', receiver.get('orderId', '0')), 'address': receiver['address']} + return [pack_receiver(receiver_obj), _to_int(str(d['amount'])), _to_int(str(d['minReturn'])), [pack_uniswap_v3_pool(p) for p in d['pools']]] + + +def _prepare_smart_swap_to_params(d): + if not all([d.get('orderId'), d.get('receiver'), d.get('baseRequest'), d.get('batchesAmount'), d.get('batches')]): + raise ValueError('Missing required parameters for smartSwapTo') + return [_to_int(str(d['orderId'])), d['receiver'], _prepare_base_request_tuple(d['baseRequest']), [_to_int(str(x)) for x in d['batchesAmount']], _prepare_batches_tuples(d['batches']), _prep_extra_data(d.get('extraData', []))] + + +def _prepare_unxswap_by_order_id_params(d): + if not all([d.get('srcToken'), d.get('amount'), d.get('minReturn'), d.get('pools')]): + raise ValueError('Missing required parameters for unxswapByOrderId') + return [pack_src_token({'orderId': d.get('orderId', '0'), 'address': d['srcToken']}), _to_int(str(d['amount'])), _to_int(str(d['minReturn'])), [pack_unxswap_pool(p) for p in d['pools']]] + + +def _prepare_unxswap_to_params(d): + if not all([d.get('srcToken'), d.get('amount'), d.get('minReturn'), d.get('receiver'), d.get('pools')]): + raise ValueError('Missing required parameters for unxswapTo') + return [pack_src_token({'orderId': d.get('orderId', '0'), 'address': d['srcToken']}), _to_int(str(d['amount'])), _to_int(str(d['minReturn'])), d['receiver'], [pack_unxswap_pool(p) for p in d['pools']]] + + +def _prepare_uniswap_v3_swap_to_with_base_request_params(d): + if not all([d.get('orderId'), d.get('receiver'), d.get('baseRequest'), d.get('pools')]): + raise ValueError('Missing required parameters for uniswapV3SwapToWithBaseRequest') + return [_to_int(str(d['orderId'])), d['receiver'], _prepare_base_request_tuple(d['baseRequest'], 'uniswapV3SwapToWithBaseRequest'), [pack_uniswap_v3_pool(p) for p in d['pools']]] + + +def _prepare_unxswap_to_with_base_request_params(d): + if not all([d.get('orderId'), d.get('receiver'), d.get('baseRequest'), d.get('pools')]): + raise ValueError('Missing required parameters for unxswapToWithBaseRequest') + return [_to_int(str(d['orderId'])), d['receiver'], _prepare_base_request_tuple(d['baseRequest'], 'unxswapToWithBaseRequest', d['orderId']), [pack_unxswap_pool(p) for p in d['pools']]] + + +def _prepare_swap_wrap_params(d): + if not all([d.get('orderId'), d.get('rawdata')]): + raise ValueError('Missing required parameters for swapWrap') + return [_to_int(str(d['orderId'])), pack_rawdata(d['rawdata'])] + + +def _prepare_swap_wrap_to_with_base_request_params(d): + if not all([d.get('orderId'), d.get('receiver'), d.get('baseRequest')]): + raise ValueError('Missing required parameters for swapWrapToWithBaseRequest') + return [_to_int(str(d['orderId'])), d['receiver'], _prepare_base_request_tuple(d['baseRequest'])] + + +def _prepare_dag_swap_by_order_id_params(d): + if not all([d.get('orderId'), d.get('baseRequest'), d.get('paths')]): + raise ValueError('Missing required parameters for dagSwapByOrderId') + return [_to_int(str(d['orderId'])), _prepare_base_request_tuple(d['baseRequest']), _prepare_dag_paths_tuples(d['paths'])] + + +def _prepare_dag_swap_to_params(d): + if not all([d.get('orderId'), d.get('receiver'), d.get('baseRequest'), d.get('paths')]): + raise ValueError('Missing required parameters for dagSwapTo') + return [_to_int(str(d['orderId'])), d['receiver'], _prepare_base_request_tuple(d['baseRequest']), _prepare_dag_paths_tuples(d['paths'])] + + +def _prepare_approve_params(d): + import re + spender, amount = d.get('spender'), d.get('amount') + if not all([spender, amount]): + raise ValueError('Missing required parameters for approve: spender and amount are required') + if not re.match(r'^0x[a-fA-F0-9]{40}$', spender): + raise ValueError('Invalid spender address format') + return [spender, _to_int(str(amount))] + + +# ============================================================================ +# ABI type definitions (from core/abi.js) +# ============================================================================ + +_BR = '(uint256,address,uint256,uint256,uint256)' +_RP = '(address[],address[],uint256[],bytes[],uint256)' +_ED = '(uint256,address,address,address,uint256,uint256,uint256,uint256,bool,bytes)' + +_FUNC_SPECS = { + 'smartSwapByOrderId': {'sig': f'smartSwapByOrderId(uint256,{_BR},uint256[],{_RP}[][],{_ED}[])', 'types': ['uint256', _BR, 'uint256[]', f'{_RP}[][]', f'{_ED}[]']}, + 'smartSwapByInvest': {'sig': f'smartSwapByInvest({_BR},uint256[],{_RP}[][],{_ED}[],address)', 'types': [_BR, 'uint256[]', f'{_RP}[][]', f'{_ED}[]', 'address']}, + 'smartSwapByInvestWithRefund': {'sig': f'smartSwapByInvestWithRefund({_BR},uint256[],{_RP}[][],{_ED}[],address,address)', 'types': [_BR, 'uint256[]', f'{_RP}[][]', f'{_ED}[]', 'address', 'address']}, + 'uniswapV3SwapTo': {'sig': 'uniswapV3SwapTo(uint256,uint256,uint256,uint256[])', 'types': ['uint256', 'uint256', 'uint256', 'uint256[]']}, + 'smartSwapTo': {'sig': f'smartSwapTo(uint256,address,{_BR},uint256[],{_RP}[][],{_ED}[])', 'types': ['uint256', 'address', _BR, 'uint256[]', f'{_RP}[][]', f'{_ED}[]']}, + 'unxswapByOrderId': {'sig': 'unxswapByOrderId(uint256,uint256,uint256,bytes32[])', 'types': ['uint256', 'uint256', 'uint256', 'bytes32[]']}, + 'unxswapTo': {'sig': 'unxswapTo(uint256,uint256,uint256,address,bytes32[])', 'types': ['uint256', 'uint256', 'uint256', 'address', 'bytes32[]']}, + 'uniswapV3SwapToWithBaseRequest': {'sig': f'uniswapV3SwapToWithBaseRequest(uint256,address,{_BR},uint256[])', 'types': ['uint256', 'address', _BR, 'uint256[]']}, + 'unxswapToWithBaseRequest': {'sig': f'unxswapToWithBaseRequest(uint256,address,{_BR},bytes32[])', 'types': ['uint256', 'address', _BR, 'bytes32[]']}, + 'swapWrap': {'sig': 'swapWrap(uint256,uint256)', 'types': ['uint256', 'uint256']}, + 'swapWrapToWithBaseRequest': {'sig': f'swapWrapToWithBaseRequest(uint256,address,{_BR})', 'types': ['uint256', 'address', _BR]}, + 'dagSwapByOrderId': {'sig': f'dagSwapByOrderId(uint256,{_BR},{_RP}[])', 'types': ['uint256', _BR, f'{_RP}[]']}, + 'dagSwapTo': {'sig': f'dagSwapTo(uint256,address,{_BR},{_RP}[])', 'types': ['uint256', 'address', _BR, f'{_RP}[]']}, + 'approve': {'sig': 'approve(address,uint256)', 'types': ['address', 'uint256']}, +} + +_PREPARE_FN = { + 'smartSwapByOrderId': _prepare_smart_swap_by_order_id_params, + 'smartSwapByInvest': _prepare_smart_swap_by_invest_params, + 'smartSwapByInvestWithRefund': _prepare_smart_swap_by_invest_with_refund_params, + 'uniswapV3SwapTo': _prepare_uniswap_v3_swap_to_params, + 'smartSwapTo': _prepare_smart_swap_to_params, + 'unxswapByOrderId': _prepare_unxswap_by_order_id_params, + 'unxswapTo': _prepare_unxswap_to_params, + 'uniswapV3SwapToWithBaseRequest': _prepare_uniswap_v3_swap_to_with_base_request_params, + 'unxswapToWithBaseRequest': _prepare_unxswap_to_with_base_request_params, + 'swapWrap': _prepare_swap_wrap_params, + 'swapWrapToWithBaseRequest': _prepare_swap_wrap_to_with_base_request_params, + 'dagSwapByOrderId': _prepare_dag_swap_by_order_id_params, + 'dagSwapTo': _prepare_dag_swap_to_params, + 'approve': _prepare_approve_params, +} + + +# ============================================================================ +# Commission encoding (from encode_commission.js) +# ============================================================================ + +def _get_commission_structure(refer_count: int) -> dict: + if not (_MIN_COMMISSION_COUNT <= refer_count <= _MAX_COMMISSION_COUNT): + raise ValueError(f'Invalid referCount: {refer_count}. Must be between {_MIN_COMMISSION_COUNT} and {_MAX_COMMISSION_COUNT}') + if refer_count == 1: + return {'blocks': ['middle', 'first'], 'name': 'SINGLE'} + elif refer_count == 2: + return {'blocks': ['first', 'middle', 'last'], 'name': 'DUAL'} + else: + blocks = [_ORDINAL_NAMES[i] for i in range(refer_count - 1)] + blocks.append('middle') + blocks.append(_ORDINAL_NAMES[refer_count - 1]) + return {'blocks': blocks, 'name': 'MULTIPLE'} + + +def _encode_commission_block(commission: dict) -> str: + if not commission.get('flag') or commission.get('rate') is None or not commission.get('address'): + raise ValueError('Commission block missing required fields: flag, rate, address') + flag = _normalize_hex(commission['flag'], _COMMISSION_BYTE_SIZE['FLAG']) + rate = _normalize_hex(hex(int(str(commission['rate']), 0 if str(commission['rate']).startswith('0x') else 10)), _COMMISSION_BYTE_SIZE['RATE']) + address = _normalize_hex(commission['address'], _COMMISSION_BYTE_SIZE['ADDRESS']) + return flag + rate + address + + +def _encode_middle_block(middle: dict, refer_count: int) -> str: + if not middle.get('token'): + raise ValueError('Middle block missing required field: token') + is_to_b_hex = '80' if middle.get('isToB', middle.get('toB', False)) else '00' + referrer_num_hex = format(refer_count, '02x') if 3 <= refer_count <= 8 else '00' + token = _normalize_hex(middle['token'], _COMMISSION_BYTE_SIZE['ADDRESS']) + return is_to_b_hex + referrer_num_hex + _PADDING + token + + +def _validate_commission_block(commission: dict) -> None: + flag, address = commission.get('flag', ''), commission.get('address', '') + if not flag or commission.get('rate') is None or not address: + raise ValueError(f'Commission blocks must have flag, rate, and address. Missing in: {commission}') + if not flag.startswith('0x') or len(flag) != 14: + raise ValueError(f'Invalid flag format: {flag}') + if not address.startswith('0x') or len(address) != 42: + raise ValueError(f'Invalid address format: {address}') + + +def validate_commission_data(commission_data: dict) -> bool: + if not isinstance(commission_data, dict): + raise ValueError('Commission data must be an object') + refer_count = commission_data.get('referCount') + if not refer_count or not (_MIN_COMMISSION_COUNT <= refer_count <= _MAX_COMMISSION_COUNT): + raise ValueError(f'Commission data must have referCount between {_MIN_COMMISSION_COUNT} and {_MAX_COMMISSION_COUNT}, got: {refer_count}') + if not commission_data.get('middle') or not commission_data.get('first'): + raise ValueError('Commission data must have middle and first properties') + if not commission_data['middle'].get('token'): + raise ValueError('Middle block must have token property') + structure = _get_commission_structure(refer_count) + for block_type in structure['blocks']: + if block_type == 'middle': + continue + if not commission_data.get(block_type): + raise ValueError(f'Commission data with referCount {refer_count} must have {block_type} property') + _validate_commission_block(commission_data[block_type]) + return True + + +def add_commission_to_calldata(calldata: str, commission_data: dict) -> str: + """Append commission encoding to calldata.""" + try: + validate_commission_data(commission_data) + calldata_hex = calldata.removeprefix('0x') + structure = _get_commission_structure(commission_data['referCount']) + encoded_blocks = [] + for block_type in structure['blocks']: + if block_type == 'middle': + encoded_blocks.append(_encode_middle_block(commission_data['middle'], commission_data['referCount'])) + else: + encoded_blocks.append(_encode_commission_block(commission_data[block_type])) + return '0x' + calldata_hex + ''.join(encoded_blocks) + except Exception as e: + raise ValueError(f'Failed to encode commission data: {e}') + + +# ============================================================================ +# Trim encoding (from encode_trim.js) +# ============================================================================ + +def _encode_trim_block(rate, address: str, flag: str) -> str: + if rate is None or not address or not flag: + raise ValueError('Trim block missing required fields: rate, address, flag') + return _normalize_hex(flag, 12) + _normalize_hex(hex(int(str(rate))), 12) + _normalize_hex(address, 40) + + +def _encode_expect_amount_block(expect_amount, flag: str, has_trim: str) -> str: + if expect_amount is None or not flag: + raise ValueError('Expect amount block missing required fields: expectAmount, flag') + is_to_b_hex = _IS_TOB_TRIM['TOB'] if has_trim == 'toB' else _IS_TOB_TRIM['TOC'] + return _normalize_hex(flag, 12) + is_to_b_hex + '00' * 5 + _normalize_hex(hex(int(str(expect_amount))), 40) + + +def _is_valid_charge_rate(r) -> bool: + return r is not None and r not in (0, '0') + + +def _is_valid_charge_address(a) -> bool: + return a is not None and a not in ('0x0000000000000000000000000000000000000000', '0x', '') + + +def validate_trim_data(trim_data: dict) -> bool: + if not isinstance(trim_data, dict): + raise ValueError('Trim data must be an object') + if not trim_data.get('trimRate') or not trim_data.get('trimAddress') or not trim_data.get('expectAmountOut'): + raise ValueError('Trim data must have trimRate, trimAddress, and expectAmountOut properties') + has_trim = trim_data.get('hasTrim') + if has_trim is not None and has_trim not in ('toB', 'toC', True, False): + raise ValueError('hasTrim must be "toB", "toC", true, or false') + charge_rate = trim_data.get('trimRate2') or trim_data.get('chargeRate') + charge_address = trim_data.get('trimAddress2') or trim_data.get('chargeAddress') + is_dual = _is_valid_charge_rate(charge_rate) and _is_valid_charge_address(charge_address) + if (charge_rate is not None or charge_address is not None) and not is_dual: + if (charge_rate is not None) != (charge_address is not None): + raise ValueError('For dual trim, both chargeRate/trimRate2 and chargeAddress/trimAddress2 must be provided') + if not trim_data['trimAddress'].startswith('0x') or len(trim_data['trimAddress']) != 42: + raise ValueError(f'Invalid trimAddress format: {trim_data["trimAddress"]}') + if is_dual and (not charge_address.startswith('0x') or len(charge_address) != 42): + raise ValueError(f'Invalid chargeAddress/trimAddress2 format: {charge_address}') + return True + + +def add_trim_to_calldata(calldata: str, trim_data: dict) -> str: + """Append trim encoding to calldata.""" + try: + calldata_hex = calldata.removeprefix('0x') + if not (trim_data.get('trimRate') and trim_data.get('trimAddress') and trim_data.get('expectAmountOut')): + raise ValueError('Trim data missing required fields: trimRate, trimAddress, expectAmountOut') + charge_rate = trim_data.get('trimRate2') or trim_data.get('chargeRate') + charge_address = trim_data.get('trimAddress2') or trim_data.get('chargeAddress') + is_dual = _is_valid_charge_rate(charge_rate) and _is_valid_charge_address(charge_address) + has_trim = trim_data.get('hasTrim') if trim_data.get('hasTrim') in ('toB', 'toC') else 'toC' + if is_dual: + calldata_hex += ( + _encode_trim_block(charge_rate, charge_address, '0x' + _TRIM_FLAGS['DUAL']) + + _encode_expect_amount_block(trim_data['expectAmountOut'], '0x' + _TRIM_FLAGS['DUAL'], has_trim) + + _encode_trim_block(trim_data['trimRate'], trim_data['trimAddress'], '0x' + _TRIM_FLAGS['DUAL']) + ) + else: + calldata_hex += ( + _encode_expect_amount_block(trim_data['expectAmountOut'], '0x' + _TRIM_FLAGS['SINGLE'], has_trim) + + _encode_trim_block(trim_data['trimRate'], trim_data['trimAddress'], '0x' + _TRIM_FLAGS['SINGLE']) + ) + return '0x' + calldata_hex + except Exception as e: + raise ValueError(f'Failed to encode trim data: {e}') + + +def add_fee_to_calldata(calldata: str, commission_data: dict = None, trim_data: dict = None) -> str: + """Append commission and/or trim encoding to calldata.""" + result = calldata + if commission_data: + result = add_commission_to_calldata(result, commission_data) + if trim_data: + result = add_trim_to_calldata(result, trim_data) + return result + + +# ============================================================================ +# Main encode function +# ============================================================================ + +def encode_functions(json_data: dict) -> str: + """Encode a DEX Router function call to calldata (selector + ABI-encoded params).""" + if not json_data or not json_data.get('function'): + raise ValueError('Invalid input: missing function information') + func_info = json_data['function'] + func_name = func_info.get('name') + func_selector = func_info.get('selector') + if not func_name or not func_selector: + raise ValueError('Invalid function information: missing name or selector') + if func_name not in _FUNC_SPECS: + raise ValueError(f'Unsupported function: {func_name}') + spec = _FUNC_SPECS[func_name] + params = _PREPARE_FN[func_name](json_data) + encoded_params = _abi_encode(spec['types'], params) + return '0x' + _to_bytes(func_selector, 4).hex() + encoded_params.hex() + + +def encode(json_data: dict, commission_data: dict = None, trim_data: dict = None) -> str: + """ + All-in-one encoder: encodes the function call then appends commission and/or trim. + + Args: + json_data: Function JSON (must include 'function.name' and 'function.selector'). + commission_data: Commission dict, or None to skip. + trim_data: Trim dict, or None to skip. + + Returns: + 0x-prefixed calldata hex string. + """ + calldata = encode_functions(json_data) + return add_fee_to_calldata(calldata, commission_data, trim_data) + + +# ============================================================================ +# CLI entry point +# ============================================================================ + +if __name__ == '__main__': + import sys + import json + + def _load(s: str): + try: + with open(s) as f: + return json.load(f) + except (FileNotFoundError, IsADirectoryError): + return json.loads(s) + + def _usage(): + print( + 'Usage: python encode.py [commission_json] [trim_json]\n' + '\n' + ' function_json — file path or inline JSON with function + params\n' + ' commission_json — (optional) file path or inline JSON for commission\n' + ' trim_json — (optional) file path or inline JSON for trim', + file=sys.stderr, + ) + sys.exit(1) + + if len(sys.argv) < 2: + _usage() + + func_data = _load(sys.argv[1]) + commission_data = _load(sys.argv[2]) if len(sys.argv) > 2 else None + trim_data = _load(sys.argv[3]) if len(sys.argv) > 3 else None + + print(encode(func_data, commission_data, trim_data))