22 changes: 22 additions & 0 deletions sauron/conftest.py
@@ -0,0 +1,22 @@
import socket
import pytest

# Store original getaddrinfo
_original_getaddrinfo = socket.getaddrinfo


def getaddrinfo_ipv4_only(host, port, family=0, type=0, proto=0, flags=0):
"""Force IPv4 only by filtering out IPv6 addresses"""
results = _original_getaddrinfo(host, port, socket.AF_INET, type, proto, flags)
if results:
ip = results[0][4][0]
print(f"Connecting to {host} via IPv4: {ip}")
return results


@pytest.fixture(scope="session", autouse=True)
def force_ipv4():
"""Force all network connections to use IPv4"""
socket.getaddrinfo = getaddrinfo_ipv4_only
yield
socket.getaddrinfo = _original_getaddrinfo
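Since the fixture above is session-scoped and autouse, every name lookup made from the test process goes through the patched resolver for the whole run. A minimal sanity-check sketch, assuming only the fixture shown above (the test itself is hypothetical and not part of this change):

# Hypothetical companion test: with force_ipv4 active, getaddrinfo should only
# ever return AF_INET entries, regardless of what the caller asked for.
import socket


def test_resolution_is_ipv4_only():
    for family, _type, _proto, _canon, sockaddr in socket.getaddrinfo("localhost", 80):
        assert family == socket.AF_INET, f"unexpected address family for {sockaddr}"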
2 changes: 1 addition & 1 deletion sauron/pyproject.toml
@@ -5,7 +5,7 @@ description = "A Bitcoin backend plugin relying on Esplora"
readme = "README.md"
requires-python = ">=3.9.2"

dependencies = ["pyln-client>=24.11", "requests[socks]>=2.23.0"]
dependencies = ["pyln-client>=24.11", "portalocker>=3.2,<4"]

[dependency-groups]
dev = [
43 changes: 43 additions & 0 deletions sauron/ratelimit.py
@@ -0,0 +1,43 @@
import time
import json
import os
import portalocker


class GlobalRateLimiter:
def __init__(
self,
rate_per_second,
state_file="/tmp/pytest_api_rate.state",
max_wait_seconds=10,
):
self.interval = 1.0 / rate_per_second
self.state_file = state_file
self.max_wait = max_wait_seconds

if not os.path.exists(self.state_file):
with open(self.state_file, "w") as f:
json.dump({"next_ts": 0.0}, f)

def acquire(self):
start = time.time()

while True:
if time.time() - start > self.max_wait:
raise TimeoutError("Rate limiter wait exceeded")

with portalocker.Lock(self.state_file, timeout=10):
with open(self.state_file, "r+") as f:
state = json.load(f)
now = time.time()

if state["next_ts"] <= now:
state["next_ts"] = now + self.interval
f.seek(0)
json.dump(state, f)
f.truncate()
return

wait = state["next_ts"] - now

time.sleep(wait)
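A hedged usage sketch for the limiter above (the wrapper name and the 2 req/s figure are illustrative assumptions; only the GlobalRateLimiter API comes from this change): each process calls acquire() before hitting the shared API, and the on-disk state file serialises the slots across workers.

# Illustrative caller: throttle concurrent workers to roughly 2 requests/second.
from ratelimit import GlobalRateLimiter

limiter = GlobalRateLimiter(rate_per_second=2, max_wait_seconds=15)


def throttled(do_request):
    # Blocks until this process owns the next slot, or raises TimeoutError
    # after max_wait_seconds.
    limiter.acquire()
    return do_request()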
219 changes: 182 additions & 37 deletions sauron/sauron.py
@@ -4,18 +4,26 @@
# requires-python = ">=3.9.2"
# dependencies = [
# "pyln-client>=24.11",
# "requests[socks]>=2.23.0",
# "portalocker>=3.2,<4",
# ]
# ///

import requests
import sys
import time

from requests.packages.urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter
import json
import socket
import os
import base64

import urllib
import urllib.request
import urllib.error
from art import sauron_eye
from pyln.client import Plugin
import portalocker

from ratelimit import GlobalRateLimiter
from shared_cache import SharedRequestCache


plugin = Plugin(dynamic=False)
@@ -27,25 +27,157 @@ class SauronError(Exception):
pass


def fetch(url):
original_getaddrinfo = socket.getaddrinfo

# def ipv4_only_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
# return original_getaddrinfo(host, port, socket.AF_INET, type, proto, flags)

# @contextmanager
# def force_ipv4():
# socket.getaddrinfo = ipv4_only_getaddrinfo
# try:
# yield
# finally:
# socket.getaddrinfo = original_getaddrinfo

rate_limiter = GlobalRateLimiter(rate_per_second=1, max_wait_seconds=15)
cache = SharedRequestCache(ttl_seconds=30)
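# Note: SharedRequestCache lives in shared_cache.py, which is not shown in this
# diff. The code below only relies on make_key(url, body=...), get(key)
# returning a cached dict or None, and set(key, dict), with entries assumed to
# expire after ttl_seconds.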


def fetch(plugin, url):
"""Fetch this {url}, maybe through a pre-defined proxy."""

# FIXME: Maybe try to be smart and renew circuit to broadcast different
# transactions ? Hint: lightningd will agressively send us the same
# transaction a certain amount of times.
session = requests.session()
session.proxies = plugin.sauron_socks_proxies
retry_strategy = Retry(
backoff_factor=1,
total=10,
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=["HEAD", "GET", "OPTIONS"],
)
adapter = HTTPAdapter(max_retries=retry_strategy)

session.mount("https://", adapter)
session.mount("http://", adapter)
class SimpleResponse:
def __init__(self, content, status_code, headers):
self.content = content
self.status_code = status_code
self.headers = headers
            try:
                self.text = content.decode("utf-8")
            except UnicodeDecodeError:
                self.text = str(content)

def json(self):
return json.loads(self.text)

key = cache.make_key(url, body="fetch")
lock_file = f"/tmp/fetch_lock_{key}.lock"

# Fast path
plugin.log(f"Checking cache for {url}", level="debug")
cached = cache.get(key)
if cached:
plugin.log(f"Cache hit for {url}", level="debug")
return SimpleResponse(
base64.b64decode(cached["content_b64"]), cached["status"], cached["headers"]
)

return session.get(url)
# Lock per URL
os.makedirs("/tmp", exist_ok=True)

max_retries = 10
backoff_factor = 1
status_forcelist = [429, 500, 502, 503, 504]

for attempt in range(max_retries + 1):
try:
plugin.log(f"Getting fetch lock for {url}", level="debug")
with portalocker.Lock(lock_file, timeout=20):
# Inside lock, re-check cache
plugin.log(f"Re-checking cache for {url}", level="debug")
cached = cache.get(key)
if cached:
plugin.log(f"Cache hit for {url}", level="debug")
return SimpleResponse(
base64.b64decode(cached["content_b64"]),
cached["status"],
cached["headers"],
)

plugin.log("Waiting for rate limit", level="debug")
rate_limiter.acquire()
plugin.log("Rate limit acquired", level="debug")

start = time.time()
plugin.log(f"Opening URL: {url}", level="debug")

# Resolve the host manually to see what address it's using
# host = urllib.parse.urlparse(url).hostname
# port = urllib.parse.urlparse(url).port or 443
# addr_info = socket.getaddrinfo(host, port)
# for family, type, proto, canonname, sockaddr in addr_info[
# :3
# ]: # Show first few
# plugin.log(
# f"Resolved {host}:{port} -> {sockaddr[0]} (family: {'IPv4' if family == socket.AF_INET else 'IPv6' if family == socket.AF_INET6 else family})",
# level="debug",
# )
with urllib.request.urlopen(url, timeout=3) as response:
elapsed = time.time() - start
plugin.log(f"Request took {elapsed:.3f}s", level="debug")

data = response.read()
status = response.status
headers = dict(response.headers)

result = SimpleResponse(data, status, headers)

cache.set(
key,
{
"status": status,
"headers": headers,
"content_b64": base64.b64encode(result.content).decode(
"ascii"
),
},
)
return result

except portalocker.exceptions.LockException:
plugin.log(f"Timeout waiting for request lock for {url}")
time.sleep(0.5)
continue

except urllib.error.HTTPError as e:
# HTTP error responses (4xx, 5xx)
plugin.log(f"HTTP {e.code} for {url}", level="debug")
data = e.read() if e.fp else b""
headers = dict(e.headers) if e.headers else {}

# Retry on specific status codes
if e.code in status_forcelist and attempt < max_retries:
wait_time = backoff_factor * (2**attempt)
plugin.log(
f"Retrying in {wait_time}s (attempt {attempt + 1}/{max_retries})",
level="debug",
)
time.sleep(wait_time)
continue

# Return error response (don't raise)
return SimpleResponse(data, e.code, headers)

except (urllib.error.URLError, OSError, ConnectionError) as e:
# Network errors (DNS, connection refused, timeout, etc.)
if attempt < max_retries:
wait_time = backoff_factor * (2**attempt)
plugin.log(
f"Network error, retrying in {wait_time}s (attempt {attempt + 1}/{max_retries}): {e}",
level="debug",
)
time.sleep(wait_time)
continue
else:
plugin.log(f"Failed after {max_retries} retries: {e}", level="error")
raise

        except Exception as e:
            plugin.log(f"Unexpected error while fetching {url}: {e}", level="error")
            raise

    # The retry loop can be exhausted by repeated lock timeouts; fail explicitly
    # rather than falling through and returning None to the callers.
    raise SauronError(f"Could not fetch {url} after {max_retries + 1} attempts")


@plugin.init()
@@ -80,7 +220,7 @@ def getchaininfo(plugin, **kwargs):
"00000008819873e925422c1ff0f99f7cc9bbb232af63a077a480a3633bee1ef6": "signet",
}

genesis_req = fetch(blockhash_url)
genesis_req = fetch(plugin, blockhash_url)
if not genesis_req.status_code == 200:
raise SauronError(
"Endpoint at {} returned {} ({}) when trying to "
@@ -89,10 +229,10 @@ def getchaininfo(plugin, **kwargs):
)
)

blockcount_req = fetch(blockcount_url)
blockcount_req = fetch(plugin, blockcount_url)
if not blockcount_req.status_code == 200:
raise SauronError(
"Endpoint at {} returned {} ({}) when trying to " "get blockcount.".format(
"Endpoint at {} returned {} ({}) when trying to get blockcount.".format(
blockcount_url, blockcount_req.status_code, blockcount_req.text
)
)
@@ -113,7 +253,7 @@ def getchaininfo(plugin, **kwargs):
@plugin.method("getrawblockbyheight")
def getrawblock(plugin, height, **kwargs):
blockhash_url = "{}/block-height/{}".format(plugin.api_endpoint, height)
blockhash_req = fetch(blockhash_url)
blockhash_req = fetch(plugin, blockhash_url)
if blockhash_req.status_code != 200:
return {
"blockhash": None,
@@ -122,7 +262,7 @@ def getrawblock(plugin, height, **kwargs):

block_url = "{}/block/{}/raw".format(plugin.api_endpoint, blockhash_req.text)
while True:
block_req = fetch(block_url)
block_req = fetch(plugin, block_url)
if block_req.status_code != 200:
return {
"blockhash": None,
@@ -150,35 +290,40 @@
def sendrawtx(plugin, tx, **kwargs):
sendtx_url = "{}/tx".format(plugin.api_endpoint)

sendtx_req = requests.post(sendtx_url, data=tx)
if sendtx_req.status_code != 200:
try:
req = urllib.request.Request(
sendtx_url, data=tx.encode() if isinstance(tx, str) else tx, method="POST"
)

with urllib.request.urlopen(req, timeout=10) as _response:
return {
"success": True,
"errmsg": "",
}

except Exception as e:
return {
"success": False,
"errmsg": sendtx_req.text,
"errmsg": str(e),
}

return {
"success": True,
"errmsg": "",
}


@plugin.method("getutxout")
def getutxout(plugin, txid, vout, **kwargs):
gettx_url = "{}/tx/{}".format(plugin.api_endpoint, txid)
status_url = "{}/tx/{}/outspend/{}".format(plugin.api_endpoint, txid, vout)

gettx_req = fetch(gettx_url)
gettx_req = fetch(plugin, gettx_url)
if not gettx_req.status_code == 200:
raise SauronError(
"Endpoint at {} returned {} ({}) when trying to " "get transaction.".format(
"Endpoint at {} returned {} ({}) when trying to get transaction.".format(
gettx_url, gettx_req.status_code, gettx_req.text
)
)
status_req = fetch(status_url)
status_req = fetch(plugin, status_url)
if not status_req.status_code == 200:
raise SauronError(
"Endpoint at {} returned {} ({}) when trying to " "get utxo status.".format(
"Endpoint at {} returned {} ({}) when trying to get utxo status.".format(
status_url, status_req.status_code, status_req.text
)
)
@@ -200,7 +345,7 @@ def getutxout(plugin, txid, vout, **kwargs):
def estimatefees(plugin, **kwargs):
feerate_url = "{}/fee-estimates".format(plugin.api_endpoint)

feerate_req = fetch(feerate_url)
feerate_req = fetch(plugin, feerate_url)
assert feerate_req.status_code == 200
feerates = feerate_req.json()
if plugin.sauron_network in ["test", "signet"]: