diff --git a/.github/workflows/typedb.yml b/.github/workflows/typedb.yml new file mode 100644 index 0000000..b4fa7e0 --- /dev/null +++ b/.github/workflows/typedb.yml @@ -0,0 +1,44 @@ +name: LocalStack TypeDB Extension Tests + +on: + pull_request: + workflow_dispatch: + +env: + LOCALSTACK_DISABLE_EVENTS: "1" + LOCALSTACK_AUTH_TOKEN: ${{ secrets.LOCALSTACK_AUTH_TOKEN }} + +jobs: + integration-tests: + name: Run Integration Tests + runs-on: ubuntu-latest + timeout-minutes: 10 + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup LocalStack and extension + run: | + cd typedb + + docker pull localstack/localstack-pro & + docker pull typedb/typedb & + pip install localstack + + make install + make dist + localstack extensions -v install file://$(ls ./dist/localstack_extension_typedb-*.tar.gz) + + DEBUG=1 localstack start -d + localstack wait + + - name: Run integration tests + run: | + cd typedb + make test + + - name: Print logs + if: always() + run: | + localstack logs + localstack stop diff --git a/README.md b/README.md index f3c2a49..1ffe6fc 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,7 @@ $ localstack extensions install "git+https://github.com/localstack/localstack-ex ## Official LocalStack Extensions Here is the current list of extensions developed by the LocalStack team and their support status. -You can install the respective extension by calling `localstack install <Install name>`. +You can install the respective extension by calling `localstack extensions install <Install name>`. 
| Extension | Install name | Version | Support status | |----------------------------------------------------------------------------------------------------| ------------ |---------| -------------- | @@ -75,6 +75,7 @@ You can install the respective extension by calling `localstack install bool: + # determine if this is a gRPC request targeting TypeDB + content_type = headers.get("content-type") or "" + req_path = headers.get(":path") or "" + is_typedb_grpc_request = ( + "grpc" in content_type and "/typedb.protocol.TypeDB" in req_path + ) + return is_typedb_grpc_request + + def request_to_port_router(self, request: Request) -> int: + # TODO add REST API / gRPC routing based on request + return 1729 diff --git a/typedb/localstack_typedb/utils/__init__.py b/typedb/localstack_typedb/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/typedb/localstack_typedb/utils/docker.py b/typedb/localstack_typedb/utils/docker.py new file mode 100644 index 0000000..2c8b03c --- /dev/null +++ b/typedb/localstack_typedb/utils/docker.py @@ -0,0 +1,192 @@ +import re +import logging +from abc import abstractmethod +from functools import cache +from typing import Callable +import requests + +from localstack import config +from localstack.config import is_env_true +from localstack_typedb.utils.h2_proxy import ( + apply_http2_patches_for_grpc_support, + ProxyRequestMatcher, +) +from localstack.utils.docker_utils import DOCKER_CLIENT +from localstack.extensions.api import Extension, http +from localstack.http import Request +from localstack.utils.container_utils.container_client import PortMappings +from localstack.utils.net import get_addressable_container_host +from localstack.utils.sync import retry +from rolo import route +from rolo.proxy import Proxy +from rolo.routing import RuleAdapter, WithHost +from werkzeug.datastructures import Headers + +LOG = logging.getLogger(__name__) +logging.getLogger("localstack_typedb").setLevel( + logging.DEBUG if config.DEBUG else 
logging.INFO +) +logging.basicConfig() + + +class ProxiedDockerContainerExtension(Extension): + """ + Utility class to create a LocalStack Extension which runs a Docker container that exposes a service + on one or more ports, with requests being proxied to that container through the LocalStack gateway. + + Requests may potentially use HTTP2 with binary content as the protocol (e.g., gRPC over HTTP2). + To ensure proper routing of requests, subclasses can define the `http2_ports`. + """ + + name: str + """Name of this extension, which must be overridden in a subclass.""" + image_name: str + """Docker image name""" + container_ports: list[int] + """List of network ports of the Docker container spun up by the extension""" + host: str | None + """ + Optional host on which to expose the container endpoints. + Can be either a static hostname, or a pattern like `myext.` + """ + path: str | None + """Optional path on which to expose the container endpoints.""" + command: list[str] | None + """Optional command (and flags) to execute in the container.""" + + request_to_port_router: Callable[[Request], int] | None + """Callable that returns the target port for a given request, for routing purposes""" + http2_ports: list[int] | None + """List of ports for which HTTP2 proxy forwarding into the container should be enabled.""" + + def __init__( + self, + image_name: str, + container_ports: list[int], + host: str | None = None, + path: str | None = None, + command: list[str] | None = None, + request_to_port_router: Callable[[Request], int] | None = None, + http2_ports: list[int] | None = None, + ): + self.image_name = image_name + if not container_ports: + raise ValueError("container_ports is required") + self.container_ports = container_ports + self.host = host + self.path = path + self.container_name = re.sub( + r"\W", "-", f"ls-ext-{self.name}" + ) + self.command = command + self.request_to_port_router = request_to_port_router + self.http2_ports = http2_ports + 
self.main_port = self.container_ports[0] + self.container_host = get_addressable_container_host() + + def update_gateway_routes(self, router: http.Router[http.RouteHandler]): + if self.path: + raise NotImplementedError( + "Path-based routing not yet implemented for this extension" + ) + # note: for simplicity, starting the external container at startup - could be optimized over time ... + self.start_container() + # add resource for HTTP/1.1 requests + resource = RuleAdapter(ProxyResource(self.container_host, self.main_port)) + if self.host: + resource = WithHost(self.host, [resource]) + router.add(resource) + + # apply patches to serve HTTP/2 requests + for port in self.http2_ports or []: + apply_http2_patches_for_grpc_support( + self.container_host, port, self.should_proxy_request + ) + + @abstractmethod + def should_proxy_request(self, headers: Headers) -> bool: + """Define whether a request should be proxied, based on request headers.""" + + def on_platform_shutdown(self): + self._remove_container() + + @cache + def start_container(self) -> None: + LOG.debug("Starting extension container %s", self.container_name) + + port_mapping = PortMappings() + for port in self.container_ports: + port_mapping.add(port) + + kwargs = {} + if self.command: + kwargs["command"] = self.command + + try: + DOCKER_CLIENT.run_container( + self.image_name, + detach=True, + remove=True, + name=self.container_name, + ports=port_mapping, + **kwargs, + ) + except Exception as e: + LOG.debug("Failed to start container %s: %s", self.container_name, e) + # allow running TypeDB in a local server in dev mode, if TYPEDB_DEV_MODE is enabled + if not is_env_true("TYPEDB_DEV_MODE"): + raise + + def _ping_endpoint(): + # TODO: allow defining a custom healthcheck endpoint ... 
+ response = requests.get(f"http://{self.container_host}:{self.main_port}/") + assert response.ok + + try: + retry(_ping_endpoint, retries=40, sleep=1) + except Exception as e: + LOG.info("Failed to connect to container %s: %s", self.container_name, e) + self._remove_container() + raise + + LOG.debug("Successfully started extension container %s", self.container_name) + + def _remove_container(self): + LOG.debug("Stopping extension container %s", self.container_name) + DOCKER_CLIENT.remove_container( + self.container_name, force=True, check_existence=False + ) + + +class ProxyResource: + """ + Simple proxy resource that forwards incoming requests from the + LocalStack Gateway to the target Docker container. + """ + + host: str + port: int + + def __init__(self, host: str, port: int): + self.host = host + self.port = port + + @route("/<path:path>") + def index(self, request: Request, path: str, *args, **kwargs): + return self._proxy_request(request, forward_path=f"/{path}") + + def _proxy_request(self, request: Request, forward_path: str, *args, **kwargs): + base_url = f"http://{self.host}:{self.port}" + proxy = Proxy(forward_base_url=base_url) + + # update content length (may have changed due to content compression) + if request.method not in ("GET", "OPTIONS"): + request.headers["Content-Length"] = str(len(request.data)) + + # make sure we're forwarding the correct Host header + request.headers["Host"] = f"localhost:{self.port}" + + # forward the request to the target + result = proxy.forward(request, forward_path=forward_path) + + return result diff --git a/typedb/localstack_typedb/utils/h2_proxy.py b/typedb/localstack_typedb/utils/h2_proxy.py new file mode 100644 index 0000000..ee533c1 --- /dev/null +++ b/typedb/localstack_typedb/utils/h2_proxy.py @@ -0,0 +1,144 @@ +import logging +import socket +from typing import Iterable, Callable + +from h2.frame_buffer import FrameBuffer +from hpack import Decoder +from hyperframe.frame import HeadersFrame, Frame +from twisted.internet 
import reactor + +from localstack.utils.patch import patch +from twisted.web._http2 import H2Connection +from werkzeug.datastructures import Headers + +LOG = logging.getLogger(__name__) + + +ProxyRequestMatcher = Callable[[Headers], bool] + +class TcpForwarder: + """Simple helper class for bidirectional forwarding of TCP traffic.""" + + buffer_size: int = 1024 + """Data buffer size for receiving data from upstream socket.""" + + def __init__(self, port: int, host: str = "localhost"): + self.port = port + self.host = host + self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._socket.connect((self.host, self.port)) + + def receive_loop(self, callback): + while data := self._socket.recv(self.buffer_size): + callback(data) + + def send(self, data): + self._socket.sendall(data) + + def close(self): + LOG.debug(f"Closing connection to upstream HTTP2 server on port {self.port}") + try: + self._socket.shutdown(socket.SHUT_RDWR) + self._socket.close() + except Exception: + # swallow exceptions here (e.g., "bad file descriptor") + pass + + +patched_connection = False + +def apply_http2_patches_for_grpc_support( + target_host: str, target_port: int, should_proxy_request: ProxyRequestMatcher +): + """ + Apply some patches to proxy incoming gRPC requests and forward them to a target port. + Note: this is a very brute-force approach and needs to be fixed/enhanced over time! + """ + LOG.debug(f"Enabling proxying to backend {target_host}:{target_port}") + global patched_connection + assert not patched_connection, "It is not safe to patch H2Connection twice with this function" + patched_connection = True + + class ForwardingBuffer: + """ + A buffer atop the HTTP2 client connection, that will hold + data until the ProxyRequestMatcher tells us whether to send it + to the backend, or leave it to the default handler. 
+ """ + def __init__(self, http_response_stream): + self.http_response_stream = http_response_stream + LOG.debug(f"Starting TCP forwarder to port {target_port} for new HTTP2 connection") + self.backend = TcpForwarder(target_port, host=target_host) + self.buffer = [] + self.proxying = False + reactor.getThreadPool().callInThread(self.backend.receive_loop, self.received_from_backend) + + def received_from_backend(self, data): + LOG.debug(f"Received {len(data)} bytes from backend") + self.http_response_stream.write(data) + + def received_from_http2_client(self, data, default_handler): + if self.proxying: + assert not self.buffer + # Keep sending data to the backend for the lifetime of this connection + self.backend.send(data) + else: + self.buffer.append(data) + if headers := get_headers_from_data_stream(self.buffer): + self.proxying = should_proxy_request(headers) + # Now we know what to do with the buffer + buffered_data = b"".join(self.buffer) + self.buffer = [] + if self.proxying: + LOG.debug(f"Forwarding {len(buffered_data)} bytes to backend") + self.backend.send(buffered_data) + else: + return default_handler(buffered_data) + + def close(self): + self.backend.close() + + @patch(H2Connection.connectionMade) + def _connectionMade(fn, self, *args, **kwargs): + self._ls_forwarding_buffer = ForwardingBuffer(self.transport) + + @patch(H2Connection.dataReceived) + def _dataReceived(fn, self, data, *args, **kwargs): + self._ls_forwarding_buffer.received_from_http2_client(data, lambda d: fn(d, *args, **kwargs)) + + @patch(H2Connection.connectionLost) + def connectionLost(fn, self, *args, **kwargs): + self._ls_forwarding_buffer.close() + + +def get_headers_from_data_stream(data_list: Iterable[bytes]) -> Headers: + """Get headers from a data stream (list of bytes data), if any headers are contained.""" + stream = b"".join(data_list) + return get_headers_from_frames(get_frames_from_http2_stream(stream)) + + +def get_headers_from_frames(frames: Iterable[Frame]) -> Headers: + 
"""Parse the given list of HTTP2 frames and return a dict of headers, if any""" + result = {} + decoder = Decoder() + for frame in frames: + if isinstance(frame, HeadersFrame): + try: + headers = decoder.decode(frame.data) + result.update(dict(headers)) + except Exception: + pass + return Headers(result) + + +def get_frames_from_http2_stream(data: bytes) -> Iterable[Frame]: + """Parse the data from an HTTP2 stream into a list of frames""" + frames = [] + buffer = FrameBuffer(server=True) + buffer.max_frame_size = 16384 + buffer.add_data(data) + try: + for frame in buffer: + yield frame + except Exception: + pass diff --git a/typedb/pyproject.toml b/typedb/pyproject.toml new file mode 100644 index 0000000..0141906 --- /dev/null +++ b/typedb/pyproject.toml @@ -0,0 +1,39 @@ +[build-system] +requires = ["setuptools", "wheel", "plux>=1.3.1"] +build-backend = "setuptools.build_meta" + +[project] +name = "localstack-extension-typedb" +version = "0.1.0" +description = "LocalStack Extension: TypeDB on LocalStack" +readme = {file = "README.md", content-type = "text/markdown; charset=UTF-8"} +requires-python = ">=3.9" +authors = [ + { name = "LocalStack + TypeDB team"} +] +keywords = ["LocalStack", "TypeDB"] +classifiers = [] +dependencies = [ + "httpx", + "h2", + "priority", +] + +[project.urls] +Homepage = "https://github.com/localstack/localstack-extensions" + +[project.optional-dependencies] +dev = [ + "boto3", + "build", + "jsonpatch", + "localstack", + "pytest", + "rolo", + "ruff", + "twisted", + "typedb-driver", +] + +[project.entry-points."localstack.extensions"] +localstack_typedb = "localstack_typedb.extension:TypeDbExtension" diff --git a/typedb/tests/test_extension.py b/typedb/tests/test_extension.py new file mode 100644 index 0000000..4bdd47f --- /dev/null +++ b/typedb/tests/test_extension.py @@ -0,0 +1,84 @@ +import requests +from localstack.utils.strings import short_uid +from localstack_typedb.utils.h2_proxy import get_frames_from_http2_stream, 
get_headers_from_frames +from typedb.driver import TypeDB, Credentials, DriverOptions, TransactionType + + +def test_connect_to_db_via_http_api(): + host = "typedb.localhost.localstack.cloud:4566" + + # get auth token + response = requests.post( + f"http://{host}/v1/signin", json={"username": "admin", "password": "password"} + ) + assert response.ok + token = response.json()["token"] + + # create database + db_name = f"db{short_uid()}" + response = requests.post( + f"http://{host}/v1/databases/{db_name}", + json={}, + headers={"Authorization": f"bearer {token}"}, + ) + assert response.ok + + # list databases + response = requests.get( + f"http://{host}/v1/databases", headers={"Authorization": f"bearer {token}"} + ) + assert response.ok + databases = [db["name"] for db in response.json()["databases"]] + assert db_name in databases + + # clean up + response = requests.delete( + f"http://{host}/v1/databases/{db_name}", + headers={"Authorization": f"bearer {token}"}, + ) + assert response.ok + + +def test_connect_to_db_via_grpc_endpoint(): + db_name = "access-management-db" + server_host = "typedb.localhost.localstack.cloud:4566" + + driver_cfg = TypeDB.driver( + server_host, + Credentials("admin", "password"), + DriverOptions(is_tls_enabled=False), + ) + with driver_cfg as driver: + if driver.databases.contains(db_name): + driver.databases.get(db_name).delete() + driver.databases.create(db_name) + + with driver.transaction(db_name, TransactionType.SCHEMA) as tx: + tx.query("define entity person;").resolve() + tx.query("define attribute name, value string; person owns name;").resolve() + tx.commit() + + with driver.transaction(db_name, TransactionType.WRITE) as tx: + tx.query("insert $p isa person, has name 'Alice';").resolve() + tx.query("insert $p isa person, has name 'Bob';").resolve() + tx.commit() + with driver.transaction(db_name, TransactionType.READ) as tx: + results = tx.query( + 'match $p isa person; fetch {"name": $p.name};' + ).resolve() + for json in 
results: + print(json) + + +def test_get_frames_from_http2_stream(): + # note: the data below is a dump taken from a browser request made against the emulator + data = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n\x00\x00\x18\x04\x00\x00\x00\x00\x00\x00\x01\x00\x01\x00\x00\x00\x02\x00\x00\x00\x00\x00\x04\x00\x02\x00\x00\x00\x05\x00\x00@\x00\x00\x00\x04\x08\x00\x00\x00\x00\x00\x00\xbf\x00\x01" + data += b"\x00\x01V\x01%\x00\x00\x00\x03\x00\x00\x00\x00\x15C\x87\xd5\xaf~MZw\x7f\x05\x8eb*\x0eA\xd0\x84\x8c\x9dX\x9c\xa3\xa13\xffA\x96\xa0\xe4\x1d\x13\x9d\t^\x83\x90t!#'U\xc9A\xed\x92\xe3M\xb8\xe7\x87z\xbe\xd0\x7ff\xa2\x81\xb0\xda\xe0S\xfa\xd02\x1a\xa4\x9d\x13\xfd\xa9\x92\xa4\x96\x854\x0c\x8aj\xdc\xa7\xe2\x81\x02\xe1o\xedK;\xdc\x0bM.\x0f\xedLE'S\xb0 \x04\x00\x08\x02\xa6\x13XYO\xe5\x80\xb4\xd2\xe0S\x83\xf9c\xe7Q\x8b-Kp\xdd\xf4Z\xbe\xfb@\x05\xdbP\x92\x9b\xd9\xab\xfaRB\xcb@\xd2_\xa5#\xb3\xe9OhL\x9f@\x94\x19\x08T!b\x1e\xa4\xd8z\x16\xb0\xbd\xad*\x12\xb5%L\xe7\x93\x83\xc5\x83\x7f@\x95\x19\x08T!b\x1e\xa4\xd8z\x16\xb0\xbd\xad*\x12\xb4\xe5\x1c\x85\xb1\x1f\x89\x1d\xa9\x9c\xf6\x1b\xd8\xd2c\xd5s\x95\x9d)\xad\x17\x18`u\xd6\xbd\x07 \xe8BFN\xab\x92\x83\xdb#\x1f@\x85=\x86\x98\xd5\x7f\x94\x9d)\xad\x17\x18`u\xd6\xbd\x07 \xe8BFN\xab\x92\x83\xdb'@\x8aAH\xb4\xa5I'ZB\xa1?\x84-5\xa7\xd7@\x8aAH\xb4\xa5I'Z\x93\xc8_\x83!\xecG@\x8aAH\xb4\xa5I'Y\x06I\x7f\x86@\xe9*\xc82K@\x86\xae\xc3\x1e\xc3'\xd7\x83\xb6\x06\xbf@\x82I\x7f\x86M\x835\x05\xb1\x1f\x00\x00\x04\x08\x00\x00\x00\x00\x03\x00\xbe\x00\x00" + + frames = get_frames_from_http2_stream(data) + assert frames + headers = get_headers_from_frames(frames) + assert headers + assert headers[":scheme"] == "https" + assert headers[":method"] == "OPTIONS" + assert headers[":path"] == "/_localstack/health"