diff --git a/scrapscript.py b/scrapscript.py index 667e0edf..805f8d1f 100755 --- a/scrapscript.py +++ b/scrapscript.py @@ -2482,6 +2482,72 @@ def flat_command(args: argparse.Namespace) -> None: sys.stdout.buffer.write(serializer.output) +def server_command(args: argparse.Namespace) -> None: + import http.server + import socketserver + import hashlib + + dir = os.path.abspath(args.directory) + if not os.path.isdir(dir): + print(f"Error: {dir} is not a valid directory") + sys.exit(1) + + scraps = {} + for root, _, files in os.walk(dir): + for file in files: + file_path = os.path.join(root, file) + rel_path = os.path.relpath(file_path, dir) + if file.startswith("$"): + logger.debug(f"Skipping {rel_path}") + continue + rel_path_without_ext = os.path.splitext(rel_path)[0] + with open(file_path, "r") as f: + try: + program = parse(tokenize(f.read())) + serializer = Serializer() + serializer.serialize(program) + serialized = bytes(serializer.output) + scraps[rel_path_without_ext] = serialized + logger.debug(f"Loaded {rel_path_without_ext}") + file_hash = hashlib.sha256(serialized).hexdigest() + scraps[f"${file_hash}"] = serialized + logger.debug(f"Loaded {rel_path_without_ext} as ${file_hash}") + except Exception as e: + logger.error(f"Error processing {file_path}: {e}") + + keep_serving = True + + class ScrapHTTPRequestHandler(http.server.SimpleHTTPRequestHandler): + def do_QUIT(self) -> None: + self.send_response(200) + self.end_headers() + self.wfile.write(b"Quitting") + nonlocal keep_serving + keep_serving = False + + def do_GET(self) -> None: + path = self.path.lstrip("/") + scrap = scraps.get(path) + if scrap is not None: + self.send_response(200) + self.send_header("Content-Type", "application/scrap; charset=binary") + self.send_header("Content-Disposition", f'attachment; filename={json.dumps(f"{path}.scrap")}') + self.send_header("Content-Length", str(len(scrap))) + self.end_headers() + self.wfile.write(scrap) + else: + self.send_response(404) + self.send_header("Content-Type", "text/plain") + self.end_headers() + self.wfile.write(b"File not found") + + handler = ScrapHTTPRequestHandler + with socketserver.TCPServer((args.host, args.port), handler) as httpd: + logger.info(f"Serving {dir} at http://{args.host}:{args.port}") + while keep_serving: + httpd.handle_request() + + def main() -> None: parser = argparse.ArgumentParser(prog="scrapscript") subparsers = parser.add_subparsers(dest="command") @@ -2521,6 +2587,16 @@ def main() -> None: flat = subparsers.add_parser("flat") flat.set_defaults(func=flat_command) + yard = subparsers.add_parser("yard") + yard.set_defaults(func=lambda _: yard.print_help()) + yard_subparsers = yard.add_subparsers(dest="yard_command") + + yard_server = yard_subparsers.add_parser("server") + yard_server.set_defaults(func=server_command) + yard_server.add_argument("directory", type=str, nargs="?", default=".", help="Directory to serve") + yard_server.add_argument("--host", type=str, default="127.0.0.1", help="Host to bind to") + yard_server.add_argument("--port", type=int, default=8080, help="Port to listen on") + args = parser.parse_args() if not args.command: args.debug = False diff --git a/scrapscript_tests.py b/scrapscript_tests.py index e1626d0e..ceecf58b 100644 --- a/scrapscript_tests.py +++ b/scrapscript_tests.py @@ -1,6 +1,7 @@ import unittest import re from typing import Optional +import urllib.request # ruff: noqa: F405 # ruff: noqa: F403 @@ -4051,5 +4052,55 @@ def test_pretty_print_variant(self) -> None: self.assertEqual(pretty(obj), "#x (a -> b)") +class ServerCommandTests(unittest.TestCase): + def setUp(self) -> None: + import threading + import time + import os + import socket + import argparse + from scrapscript import server_command + + # Find a random available port + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + self.host, self.port = s.getsockname() + + args = argparse.Namespace( + directory=os.path.join(os.path.dirname(__file__), "examples"), + host=self.host, + port=self.port, + ) + + self.server_thread = threading.Thread(target=server_command, args=(args,)) + self.server_thread.daemon = True + self.server_thread.start() + + # Wait for the server to start + while True: + try: + with socket.create_connection((self.host, self.port), timeout=0.1) as s: + break + except (ConnectionRefusedError, socket.timeout): + time.sleep(0.01) + + def tearDown(self) -> None: + quit_request = urllib.request.Request(f"http://{self.host}:{self.port}/", method="QUIT") + urllib.request.urlopen(quit_request) + + def test_server_serves_scrap_by_path(self) -> None: + response = urllib.request.urlopen(f"http://{self.host}:{self.port}/0_home/factorial") + self.assertEqual(response.status, 200) + + def test_server_serves_scrap_by_hash(self) -> None: + response = urllib.request.urlopen(f"http://{self.host}:{self.port}/$09242a8dfec0ed32eb9ddd5452f0082998712d35306fec2042bad8ac5b6e9580") + self.assertEqual(response.status, 200) + + def test_server_fails_missing_scrap(self) -> None: + with self.assertRaises(urllib.error.HTTPError) as cm: + urllib.request.urlopen(f"http://{self.host}:{self.port}/foo") + self.assertEqual(cm.exception.code, 404) + + if __name__ == "__main__": unittest.main()